Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
19d8f1c2
Unverified
Commit
19d8f1c2
authored
Oct 22, 2021
by
moto
Committed by
GitHub
Oct 22, 2021
Browse files
Refactor integration test (#1922)
- Make the test support other languages - Fetch tetst asset on-the-fly
parent
716aa416
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
17 deletions
+34
-17
test/integration_tests/conftest.py
test/integration_tests/conftest.py
+19
-3
test/integration_tests/wav2vec2_pipeline_test.py
test/integration_tests/wav2vec2_pipeline_test.py
+15
-14
test/torchaudio_unittest/assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac
...t/assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac
+0
-0
No files found.
test/integration_tests/conftest.py
View file @
19d8f1c2
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
get_asset_path
import
requests
import
pytest
import
pytest
...
@@ -32,6 +32,22 @@ def ctc_decoder():
...
@@ -32,6 +32,22 @@ def ctc_decoder():
return
GreedyCTCDecoder
return
GreedyCTCDecoder
_FILES
=
{
'en'
:
'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac'
,
}
@
pytest
.
fixture
@
pytest
.
fixture
def
sample_speech_16000_en
():
def
sample_speech
(
tmp_path
,
lang
):
return
get_asset_path
(
'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac'
)
if
lang
not
in
_FILES
:
raise
NotImplementedError
(
f
'Unexpected lang:
{
lang
}
'
)
filename
=
_FILES
[
lang
]
path
=
tmp_path
.
parent
/
filename
if
not
path
.
exists
():
url
=
f
'https://download.pytorch.org/torchaudio/test-assets/
{
filename
}
'
print
(
f
'downloading from
{
url
}
'
)
with
open
(
path
,
'wb'
)
as
file
:
with
requests
.
get
(
url
)
as
resp
:
resp
.
raise_for_status
()
file
.
write
(
resp
.
content
)
return
path
test/integration_tests/wav2vec2_pipeline_test.py
View file @
19d8f1c2
...
@@ -40,30 +40,31 @@ def test_pretraining_models(bundle):
...
@@ -40,30 +40,31 @@ def test_pretraining_models(bundle):
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"bundle,expected"
,
"bundle,
lang,
expected"
,
[
[
(
WAV2VEC2_ASR_BASE_10M
,
'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_BASE_10M
,
'en'
,
'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_BASE_100H
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_BASE_100H
,
'en'
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_BASE_960H
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_BASE_960H
,
'en'
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_10M
,
'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_10M
,
'en'
,
'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_100H
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_100H
,
'en'
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_960H
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_960H
,
'en'
,
'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_LV60K_10M
,
'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_LV60K_10M
,
'en'
,
'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_LV60K_100H
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_LV60K_100H
,
'en'
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_LV60K_960H
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
WAV2VEC2_ASR_LARGE_LV60K_960H
,
'en'
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
HUBERT_ASR_LARGE
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
HUBERT_ASR_LARGE
,
'en'
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
),
(
HUBERT_ASR_XLARGE
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
)
(
HUBERT_ASR_XLARGE
,
'en'
,
'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'
)
,
]
]
)
)
def
test_finetune_asr_model
(
def
test_finetune_asr_model
(
bundle
,
bundle
,
lang
,
expected
,
expected
,
sample_speech
_16000_en
,
sample_speech
,
ctc_decoder
,
ctc_decoder
,
):
):
"""Smoke test of downloading weights for fine-tuning models and simple transcription"""
"""Smoke test of downloading weights for fine-tuning models and simple transcription"""
model
=
bundle
.
get_model
().
eval
()
model
=
bundle
.
get_model
().
eval
()
waveform
,
sample_rate
=
torchaudio
.
load
(
sample_speech
_16000_en
)
waveform
,
sample_rate
=
torchaudio
.
load
(
sample_speech
)
emission
,
_
=
model
(
waveform
)
emission
,
_
=
model
(
waveform
)
decoder
=
ctc_decoder
(
bundle
.
get_labels
())
decoder
=
ctc_decoder
(
bundle
.
get_labels
())
result
=
decoder
(
emission
[
0
])
result
=
decoder
(
emission
[
0
])
...
...
test/torchaudio_unittest/assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac
deleted
100644 → 0
View file @
716aa416
File deleted
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment