Unverified Commit 2a1aaff3 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Fix invalid test names generation (#1374)

parent 0ea475af
...@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import (
skipIfNoModule, skipIfNoModule,
get_wav_data, get_wav_data,
save_wav, save_wav,
nested_params,
) )
from torchaudio_unittest.backend.common import ( from torchaudio_unittest.backend.common import (
get_bits_per_sample, get_bits_per_sample,
...@@ -77,7 +78,14 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -77,7 +78,14 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 0 assert info.bits_per_sample == 0
assert info.encoding == "VORBIS" assert info.encoding == "VORBIS"
@parameterize([8000, 16000], [1, 2], [('PCM_24', 24), ('PCM_32', 32)]) @nested_params(
[8000, 16000],
[1, 2],
[
('PCM_24', 24),
('PCM_32', 32)
],
)
@skipIfFormatNotSupported("NIST") @skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth): def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
"""`soundfile_backend.info` can check sph file correctly""" """`soundfile_backend.info` can check sph file correctly"""
......
...@@ -10,6 +10,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -10,6 +10,7 @@ from torchaudio_unittest.common_utils import (
skipIfNoModule, skipIfNoModule,
get_wav_data, get_wav_data,
load_wav, load_wav,
nested_params,
) )
from .common import ( from .common import (
fetch_wav_subtype, fetch_wav_subtype,
...@@ -22,8 +23,11 @@ if _mod_utils.is_module_available("soundfile"): ...@@ -22,8 +23,11 @@ if _mod_utils.is_module_available("soundfile"):
class MockedSaveTest(PytorchTestCase): class MockedSaveTest(PytorchTestCase):
@parameterize( @nested_params(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True], ["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
[False, True],
[ [
(None, None), (None, None),
('PCM_U', None), ('PCM_U', None),
...@@ -101,7 +105,7 @@ class MockedSaveTest(PytorchTestCase): ...@@ -101,7 +105,7 @@ class MockedSaveTest(PytorchTestCase):
assert args["format"] is None assert args["format"] is None
self.assertEqual(args["data"], expected_data) self.assertEqual(args["data"], expected_data)
@parameterize( @nested_params(
["sph", "nist", "nis"], ["sph", "nist", "nis"],
["int32", "int16"], ["int32", "int16"],
[8000, 16000], [8000, 16000],
...@@ -240,7 +244,7 @@ class TestSave(SaveTestBase): ...@@ -240,7 +244,7 @@ class TestSave(SaveTestBase):
class TestSaveParams(TempDirMixin, PytorchTestCase): class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `soundfile_backend.save`""" """Test the correctness of optional parameters of `soundfile_backend.save`"""
@parameterize([(True,), (False,)]) @parameterize([True, False])
def test_channels_first(self, channels_first): def test_channels_first(self, channels_first):
"""channels_first swaps axes""" """channels_first swaps axes"""
path = self.get_temp_path("data.wav") path = self.get_temp_path("data.wav")
......
import io import io
import unittest import unittest
from itertools import product
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from parameterized import parameterized from parameterized import parameterized
...@@ -15,6 +14,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -15,6 +14,7 @@ from torchaudio_unittest.common_utils import (
load_wav, load_wav,
save_wav, save_wav,
sox_utils, sox_utils,
nested_params,
) )
from .common import ( from .common import (
name_func, name_func,
...@@ -140,22 +140,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -140,22 +140,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
self.assertEqual(found, expected) self.assertEqual(found, expected)
def nested_params(*params):
def _name_func(func, _, params):
strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
return f'{func.__name__}_{"_".join(strs)}'
return parameterized.expand(
list(product(*params)),
name_func=_name_func
)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoSox @skipIfNoSox
class SaveTest(SaveTestBase): class SaveTest(SaveTestBase):
......
...@@ -27,9 +27,11 @@ from .wav_utils import ( ...@@ -27,9 +27,11 @@ from .wav_utils import (
) )
from .parameterized_utils import ( from .parameterized_utils import (
load_params, load_params,
nested_params
) )
__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend', __all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase', 'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoKaldi', 'skipIfNoSox', 'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoKaldi', 'skipIfNoSox',
'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params'] 'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params',
'nested_params']
import json import json
from itertools import product
from parameterized import param from parameterized import param, parameterized
from .data_utils import get_asset_path from .data_utils import get_asset_path
...@@ -8,3 +9,19 @@ from .data_utils import get_asset_path ...@@ -8,3 +9,19 @@ from .data_utils import get_asset_path
def load_params(*paths): def load_params(*paths):
with open(get_asset_path(*paths), 'r') as file: with open(get_asset_path(*paths), 'r') as file:
return [param(json.loads(line)) for line in file] return [param(json.loads(line)) for line in file]
def nested_params(*params):
def _name_func(func, _, params):
strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
return f'{func.__name__}_{"_".join(strs)}'
return parameterized.expand(
list(product(*params)),
name_func=_name_func
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment