Unverified Commit 1e7d8d20 authored by Krishna Kalyan's avatar Krishna Kalyan Committed by GitHub
Browse files

Replace pytest's parameterization with parameterized (#1157)



Also replaces `assert_allclose` with `assertEqual`.
Co-authored-by: default avatarkrishnakalyan3 <skalyan@cloudera.com>
parent df48ba36
...@@ -5,6 +5,7 @@ import torchaudio ...@@ -5,6 +5,7 @@ import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized from parameterized import parameterized
import pytest import pytest
import itertools
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram from .functional_impl import Lfilter, Spectrogram
...@@ -53,7 +54,7 @@ class TestComputeDeltas(common_utils.TorchaudioTestCase): ...@@ -53,7 +54,7 @@ class TestComputeDeltas(common_utils.TorchaudioTestCase):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]]) specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]) expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3) computed = F.compute_deltas(specgram, win_length=3)
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_two_channels(self): def test_two_channels(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0], specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
...@@ -61,7 +62,7 @@ class TestComputeDeltas(common_utils.TorchaudioTestCase): ...@@ -61,7 +62,7 @@ class TestComputeDeltas(common_utils.TorchaudioTestCase):
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]]) [0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3) computed = F.compute_deltas(specgram, win_length=3)
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase): class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
...@@ -97,13 +98,13 @@ class TestDB_to_amplitude(common_utils.TorchaudioTestCase): ...@@ -97,13 +98,13 @@ class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5) self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
# Spectrogram amplitude -> DB -> amplitude # Spectrogram amplitude -> DB -> amplitude
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5) self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
# Waveform power -> DB -> power # Waveform power -> DB -> power
multiplier = 10. multiplier = 10.
...@@ -112,61 +113,66 @@ class TestDB_to_amplitude(common_utils.TorchaudioTestCase): ...@@ -112,61 +113,66 @@ class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5) self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
# Spectrogram power -> DB -> power # Spectrogram power -> DB -> power
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5) self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
@pytest.mark.parametrize('complex_tensor', [ class TestComplexNorm(common_utils.TorchaudioTestCase):
torch.randn(1, 2, 1025, 400, 2), @parameterized.expand(list(itertools.product(
torch.randn(1025, 400, 2) [(1, 2, 1025, 400, 2), (1025, 400, 2)],
]) [1, 2, 0.7]
@pytest.mark.parametrize('power', [1, 2, 0.7]) )))
def test_complex_norm(complex_tensor, power): def test_complex_norm(self, shape, power):
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2) torch.random.manual_seed(42)
norm_tensor = F.complex_norm(complex_tensor, power) complex_tensor = torch.randn(*shape)
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
norm_tensor = F.complex_norm(complex_tensor, power)
self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
torch.testing.assert_allclose(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
class TestMaskAlongAxis(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[(2, 1025, 400), (1, 201, 100)],
[100],
[0., 30.],
[1, 2]
)))
def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
torch.random.manual_seed(42)
specgram = torch.randn(*shape)
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
@pytest.mark.parametrize('specgram', [ other_axis = 1 if axis == 2 else 2
torch.randn(2, 1025, 400),
torch.randn(1, 201, 100)
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [1, 2])
def test_mask_along_axis(specgram, mask_param, mask_value, axis):
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) masked_columns = (mask_specgram == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns //= mask_specgram.size(0)
other_axis = 1 if axis == 2 else 2 assert mask_specgram.size() == specgram.size()
assert num_masked_columns < mask_param
masked_columns = (mask_specgram == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns //= mask_specgram.size(0)
assert mask_specgram.size() == specgram.size() class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase):
assert num_masked_columns < mask_param @parameterized.expand(list(itertools.product(
[100],
[0., 30.],
[2, 3]
)))
def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
torch.random.manual_seed(42)
specgrams = torch.randn(4, 2, 1025, 400)
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
@pytest.mark.parametrize('mask_param', [100]) other_axis = 2 if axis == 3 else 3
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [2, 3])
def test_mask_along_axis_iid(mask_param, mask_value, axis):
torch.random.manual_seed(42)
specgrams = torch.randn(4, 2, 1025, 400)
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) masked_columns = (mask_specgrams == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)
other_axis = 2 if axis == 3 else 3 assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
masked_columns = (mask_specgrams == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)
assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
import torchaudio import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import is_module_available
from parameterized import parameterized
import itertools
LIBROSA_AVAILABLE = is_module_available('librosa') LIBROSA_AVAILABLE = is_module_available('librosa')
...@@ -111,42 +113,49 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -111,42 +113,49 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5) self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
@pytest.mark.parametrize('complex_specgrams', [
torch.randn(2, 1025, 400, 2)
])
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('hop_length', [256])
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
def test_phase_vocoder(complex_specgrams, rate, hop_length): class TestPhaseVocoder(common_utils.TorchaudioTestCase):
# Due to cumulative sum, numerical error in using torch.float32 will @parameterized.expand(list(itertools.product(
# result in bottom right values of the stretched spectrogram to not [(2, 1025, 400, 2)],
# match with librosa. [0.5, 1.01, 1.3],
[256]
complex_specgrams = complex_specgrams.type(torch.float64) )))
phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3], dtype=torch.float64)[..., None] def test_phase_vocoder(self, shape, rate, hop_length):
# Due to cummulative sum, numerical error in using torch.float32 will
complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance) # result in bottom right values of the stretched spectrogram to not
# match with librosa.
# == Test shape torch.random.manual_seed(42)
expected_size = list(complex_specgrams.size()) complex_specgrams = torch.randn(*shape)
expected_size[-2] = int(np.ceil(expected_size[-2] / rate)) complex_specgrams = complex_specgrams.type(torch.float64)
phase_advance = torch.linspace(
assert complex_specgrams.dim() == complex_specgrams_stretch.dim() 0,
assert complex_specgrams_stretch.size() == torch.Size(expected_size) np.pi * hop_length,
complex_specgrams.shape[-3],
# == Test values dtype=torch.float64)[..., None]
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
mono_complex_specgram = complex_specgrams[index].numpy() complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j # == Test shape
expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram, expected_size = list(complex_specgrams.size())
rate=rate, expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
hop_length=hop_length)
assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
complex_stretch = complex_specgrams_stretch[index].numpy() assert complex_specgrams_stretch.size() == torch.Size(expected_size)
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
# == Test values
assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5) index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
mono_complex_specgram = complex_specgrams[index].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
expected_complex_stretch = librosa.phase_vocoder(
mono_complex_specgram,
rate=rate,
hop_length=hop_length)
complex_stretch = complex_specgrams_stretch[index].numpy()
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
def _load_audio_asset(*asset_paths, **kwargs): def _load_audio_asset(*asset_paths, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment