Unverified Commit 64551a69 authored by Jcaw's avatar Jcaw Committed by GitHub
Browse files

Apply misc updates to `functional/batch_consistency_test.py` (#1341)

* Parameterize `test_sliding_window_cmn`

* Extract test naming function

* Pass a spectrogram to `F.sliding_window_cmn`

* Set manual seed for remaining rand calls in suite
parent 9a96fb7e
...@@ -9,6 +9,13 @@ import torchaudio.functional as F ...@@ -9,6 +9,13 @@ import torchaudio.functional as F
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
def _name_from_args(func, _, params):
"""Return a parameterized test name, based on parameter values."""
return "{}_{}".format(
func.__name__,
"_".join(str(arg) for arg in params.args))
@parameterized_class([ @parameterized_class([
# Single-item batch isolates problems that come purely from adding a # Single-item batch isolates problems that come purely from adding a
# dimension (rather than processing multiple items) # dimension (rather than processing multiple items)
...@@ -58,7 +65,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -58,7 +65,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[8000, 16000, 44100], [8000, 16000, 44100],
[1, 2], [1, 2],
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') )), name_func=_name_from_args)
def test_detect_pitch_frequency(self, sample_rate, n_channels): def test_detect_pitch_frequency(self, sample_rate, n_channels):
# Use different frequencies to ensure each item in the batch returns a # Use different frequencies to ensure each item in the batch returns a
# different answer. # different answer.
...@@ -180,16 +187,16 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -180,16 +187,16 @@ class TestFunctional(common_utils.TorchaudioTestCase):
sample_rate = 44100 sample_rate = 44100
self.assert_batch_consistency(F.flanger, waveforms, sample_rate) self.assert_batch_consistency(F.flanger, waveforms, sample_rate)
def test_sliding_window_cmn(self): @parameterized.expand(list(itertools.product(
waveforms = torch.randn(self.batch_size, 2, 1024) - 0.5 [True, False], # center
self.assert_batch_consistency( [True, False], # norm_vars
F.sliding_window_cmn, waveforms, center=True, norm_vars=True) )), name_func=_name_from_args)
self.assert_batch_consistency( def test_sliding_window_cmn(self, center, norm_vars):
F.sliding_window_cmn, waveforms, center=True, norm_vars=False) torch.manual_seed(0)
self.assert_batch_consistency( spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
F.sliding_window_cmn, waveforms, center=False, norm_vars=True)
self.assert_batch_consistency( self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=False) F.sliding_window_cmn, spectrogram, center=center,
norm_vars=norm_vars)
def test_vad_from_file(self): def test_vad_from_file(self):
filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav") filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
...@@ -202,6 +209,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -202,6 +209,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_vad_different_items(self): def test_vad_different_items(self):
"""Separate test to ensure VAD consistency with differing items.""" """Separate test to ensure VAD consistency with differing items."""
sample_rate = 44100 sample_rate = 44100
torch.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency( self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate) F.vad, waveforms, sample_rate=sample_rate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment