Unverified Commit 5aedcab3 authored by moto's avatar moto Committed by GitHub
Browse files

Tweak test name by appending factory function name (#1780)

Tweak the test names so that it is easier to see which tests are failing.

Before: `test_import_finetuned_model_2`
After:  `test_import_finetuned_model_2_wav2vec2_large_lv60k`
parent 78b08c26
......@@ -23,6 +23,10 @@ def _load_config(*paths):
return json.load(file_)
def _name_func(testcase_func, i, param):
return f'{testcase_func.__name__}_{i}_{param[0][1].__name__}'
# Pretrined (not fine-tuned) models
BASE = _load_config('wav2vec_small')
LARGE = _load_config('libri960_big')
......@@ -35,18 +39,18 @@ LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h')
LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h')
# Config and corresponding factory functions
PRETRAINED_CONFIGS = [
PRETRAINED_CONFIGS = parameterized.expand([
(BASE, wav2vec2_base),
(LARGE, wav2vec2_large),
(LARGE_LV60K, wav2vec2_large_lv60k),
(XLSR_53_56K, wav2vec2_large_lv60k),
]
FINETUNED_CONFIGS = [
], name_func=_name_func)
FINETUNED_CONFIGS = parameterized.expand([
(BASE_960H, wav2vec2_base),
(LARGE_960H, wav2vec2_large),
(LARGE_LV60K_960H, wav2vec2_large_lv60k),
(LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k),
]
], name_func=_name_func)
@skipIfNoModule('fairseq')
......@@ -75,9 +79,10 @@ class TestFairseqIntegration(TorchaudioTestCase):
return Wav2VecEncoder(Wav2Vec2CtcConfig(**config), num_out)
if config['_name'] == 'wav2vec2':
return Wav2Vec2Model(Wav2Vec2Config(**config))
raise ValueError(f'Unexpected configuration: {config["_name"]}')
@parameterized.expand([conf[:1] for conf in PRETRAINED_CONFIGS])
def test_import_pretrained_model(self, config):
@PRETRAINED_CONFIGS
def test_import_pretrained_model(self, config, _):
"""Pretrained wav2vec2 models from fairseq can be imported and yields the same results"""
num_out = 28
batch_size, num_frames = 3, 1024
......@@ -91,7 +96,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
for i, (ref, _) in enumerate(refs['layer_results']):
self.assertEqual(hyp[i], ref.transpose(0, 1))
@parameterized.expand(PRETRAINED_CONFIGS)
@PRETRAINED_CONFIGS
def test_recreate_pretrained_model(self, config, factory_func):
"""Imported pretrained models can be recreated via a factory function without fairseq."""
num_out = 28
......@@ -117,8 +122,8 @@ class TestFairseqIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp)
self.assertEqual(ref_lengths, hyp_lengths)
@parameterized.expand([conf[:1] for conf in FINETUNED_CONFIGS])
def test_import_finetuned_model(self, config):
@FINETUNED_CONFIGS
def test_import_finetuned_model(self, config, _):
"""Fintuned wav2vec2 models from fairseq can be imported and yields the same results"""
num_out = 28
batch_size, num_frames = 3, 1024
......@@ -140,7 +145,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
for i, l in enumerate(output_lengths):
self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
@parameterized.expand(FINETUNED_CONFIGS)
@FINETUNED_CONFIGS
def test_recreate_finetuned_model(self, config, factory_func):
"""Imported finetuned models can be recreated via a factory function without fairseq."""
num_out = 28
......
......@@ -21,6 +21,10 @@ def _load_config(*paths):
return json.load(file_)
def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][1].__name__}"
# Pretrained
HF_BASE = _load_config('facebook', 'wav2vec2-base')
HF_LARGE = _load_config('facebook', 'wav2vec2-large')
......@@ -35,7 +39,7 @@ HF_LARGE_LV60_SELF_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60-sel
HF_LARGE_XLSR_DE = _load_config('facebook', 'wav2vec2-large-xlsr-53-german')
# Config and corresponding factory functions
HF_CONFIGS = [
HF_CONFIGS = parameterized.expand([
# pretrained
(HF_BASE, wav2vec2_base),
(HF_LARGE, wav2vec2_large),
......@@ -48,7 +52,7 @@ HF_CONFIGS = [
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
]
], name_func=_name_func)
@skipIfNoModule('transformers')
......@@ -72,8 +76,8 @@ class TestHFIntegration(TorchaudioTestCase):
)
return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
@parameterized.expand([cfg[:1] for cfg in HF_CONFIGS])
def test_import(self, config):
@HF_CONFIGS
def test_import(self, config, _):
"""wav2vec2 models from HF transformers can be imported and yields the same results"""
original = self._get_model(config).eval()
imported = import_huggingface_model(original).eval()
......@@ -138,7 +142,7 @@ class TestHFIntegration(TorchaudioTestCase):
for i, l in enumerate(output_lengths):
self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
@parameterized.expand(HF_CONFIGS)
@HF_CONFIGS
def test_recreate(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval()
......
......@@ -15,8 +15,8 @@ from torchaudio_unittest.common_utils import (
from parameterized import parameterized
def _name_func(testcase_func, _, param):
return f"{testcase_func.__name__}_{param[0][0].__name__}"
def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
factory_funcs = parameterized.expand([
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment