Unverified Commit 64551a69 authored by Jcaw's avatar Jcaw Committed by GitHub
Browse files

Apply misc updates to `functional/batch_consistency_test.py` (#1341)

* Parameterize `test_sliding_window_cmn`

* Extract test naming function

* Pass a spectrogram to `F.sliding_window_cmn`

* Set manual seed for remaining rand calls in suite
parent 9a96fb7e
......@@ -9,6 +9,13 @@ import torchaudio.functional as F
from torchaudio_unittest import common_utils
def _name_from_args(func, _, params):
"""Return a parameterized test name, based on parameter values."""
return "{}_{}".format(
func.__name__,
"_".join(str(arg) for arg in params.args))
@parameterized_class([
# Single-item batch isolates problems that come purely from adding a
# dimension (rather than processing multiple items)
......@@ -58,7 +65,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[8000, 16000, 44100],
[1, 2],
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
)), name_func=_name_from_args)
def test_detect_pitch_frequency(self, sample_rate, n_channels):
# Use different frequencies to ensure each item in the batch returns a
# different answer.
......@@ -180,16 +187,16 @@ class TestFunctional(common_utils.TorchaudioTestCase):
sample_rate = 44100
self.assert_batch_consistency(F.flanger, waveforms, sample_rate)
def test_sliding_window_cmn(self):
waveforms = torch.randn(self.batch_size, 2, 1024) - 0.5
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=True)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=False)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=True)
@parameterized.expand(list(itertools.product(
[True, False], # center
[True, False], # norm_vars
)), name_func=_name_from_args)
def test_sliding_window_cmn(self, center, norm_vars):
torch.manual_seed(0)
spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=False)
F.sliding_window_cmn, spectrogram, center=center,
norm_vars=norm_vars)
def test_vad_from_file(self):
filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
......@@ -202,6 +209,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_vad_different_items(self):
"""Separate test to ensure VAD consistency with differing items."""
sample_rate = 44100
torch.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment