Unverified Commit 19d8f1c2 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor integration test (#1922)

- Make the test support other languages
- Fetch test asset on-the-fly
parent 716aa416
import torch import torch
from torchaudio_unittest.common_utils import get_asset_path import requests
import pytest import pytest
...@@ -32,6 +32,22 @@ def ctc_decoder(): ...@@ -32,6 +32,22 @@ def ctc_decoder():
return GreedyCTCDecoder return GreedyCTCDecoder
_FILES = {
'en': 'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac',
}
@pytest.fixture @pytest.fixture
def sample_speech_16000_en(): def sample_speech(tmp_path, lang):
return get_asset_path('Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac') if lang not in _FILES:
raise NotImplementedError(f'Unexpected lang: {lang}')
filename = _FILES[lang]
path = tmp_path.parent / filename
if not path.exists():
url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}'
print(f'downloading from {url}')
with open(path, 'wb') as file:
with requests.get(url) as resp:
resp.raise_for_status()
file.write(resp.content)
return path
...@@ -40,30 +40,31 @@ def test_pretraining_models(bundle): ...@@ -40,30 +40,31 @@ def test_pretraining_models(bundle):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"bundle,expected", "bundle,lang,expected",
[ [
(WAV2VEC2_ASR_BASE_10M, 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_BASE_10M, 'en', 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_BASE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_BASE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_BASE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_BASE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_10M, 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_10M, 'en', 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_LV60K_10M, 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'), (WAV2VEC2_ASR_LARGE_LV60K_10M, 'en', 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_LV60K_100H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_LV60K_100H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_LV60K_960H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_LV60K_960H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(HUBERT_ASR_LARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (HUBERT_ASR_LARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(HUBERT_ASR_XLARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|') (HUBERT_ASR_XLARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
] ]
) )
def test_finetune_asr_model( def test_finetune_asr_model(
bundle, bundle,
lang,
expected, expected,
sample_speech_16000_en, sample_speech,
ctc_decoder, ctc_decoder,
): ):
"""Smoke test of downloading weights for fine-tuning models and simple transcription""" """Smoke test of downloading weights for fine-tuning models and simple transcription"""
model = bundle.get_model().eval() model = bundle.get_model().eval()
waveform, sample_rate = torchaudio.load(sample_speech_16000_en) waveform, sample_rate = torchaudio.load(sample_speech)
emission, _ = model(waveform) emission, _ = model(waveform)
decoder = ctc_decoder(bundle.get_labels()) decoder = ctc_decoder(bundle.get_labels())
result = decoder(emission[0]) result = decoder(emission[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment