Unverified commit b8187b4c, authored by moto and committed by GitHub
Browse files

Adopt PyTorch's test util to batch consistency test (#643)

parent 44af0dea
"""Test numerical consistency among single input and batched input.""" """Test numerical consistency among single input and batched input."""
import unittest import unittest
import platform
import torch import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
import common_utils import common_utils
def _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs): class TestFunctional(TestCase):
# run then batch the result
torch.random.manual_seed(seed)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.repeat([batch_size] + [1] * expected.dim())
# batch the input and run
torch.random.manual_seed(seed)
pattern = [batch_size] + [1] * tensor.dim()
computed = functional(tensor.repeat(pattern), *args, **kwargs)
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
    """Run the batch consistency check for batch sizes 1 and 3.

    All positional and keyword arguments are forwarded unchanged to
    ``_test_batch_consistency``; only ``batch_size`` varies between the runs.
    """
    for size in (1, 3):
        _test_batch_consistency(
            functional, tensor, *args,
            batch_size=size, atol=atol, rtol=rtol, seed=seed, **kwargs)
class TestFunctional(unittest.TestCase):
"""Test functions defined in `functional` module""" """Test functions defined in `functional` module"""
def assert_batch_consistency(
        self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
    """Assert that ``functional`` gives matching results for single and batched input.

    The reference is built by calling ``functional`` on a clone of ``tensor``
    and repeating the result ``batch_size`` times along a new leading
    dimension. The candidate is built by repeating ``tensor`` the same way and
    calling ``functional`` once on the batch. The torch RNG is re-seeded with
    ``seed`` before each call to ``functional``.

    Args:
        functional: Callable invoked as ``functional(input, *args, **kwargs)``.
        tensor: Un-batched input tensor.
        *args: Extra positional arguments forwarded to ``functional``.
        batch_size: Number of copies stacked along the new leading dimension.
        atol: Absolute tolerance passed to ``self.assertEqual``.
        rtol: Relative tolerance passed to ``self.assertEqual``.
        seed: Value passed to ``torch.random.manual_seed`` before each call.
        **kwargs: Extra keyword arguments forwarded to ``functional``.
    """
    # Reference path: transform the single input, then tile the output.
    torch.random.manual_seed(seed)
    single_output = functional(tensor.clone(), *args, **kwargs)
    output_tiling = [batch_size, *([1] * single_output.dim())]
    expected = single_output.repeat(output_tiling)

    # Candidate path: tile the input, then transform the whole batch.
    torch.random.manual_seed(seed)
    input_tiling = [batch_size, *([1] * tensor.dim())]
    batched_input = tensor.repeat(input_tiling)
    computed = functional(batched_input, *args, **kwargs)

    self.assertEqual(computed, expected, rtol=rtol, atol=atol)
def assert_batch_consistencies(
        self, functional, tensor, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
    """Run ``assert_batch_consistency`` with batch sizes 1 and 3.

    Every argument is forwarded unchanged; only ``batch_size`` differs
    between the two checks.
    """
    for size in (1, 3):
        self.assert_batch_consistency(
            functional, tensor, *args,
            batch_size=size, atol=atol, rtol=rtol, seed=seed, **kwargs)
def test_griffinlim(self): def test_griffinlim(self):
n_fft = 400 n_fft = 400
ws = 400 ws = 400
...@@ -41,7 +43,7 @@ class TestFunctional(unittest.TestCase): ...@@ -41,7 +43,7 @@ class TestFunctional(unittest.TestCase):
n_iter = 32 n_iter = 32
length = 1000 length = 1000
tensor = torch.rand((1, 201, 6)) tensor = torch.rand((1, 201, 6))
_test_batch( self.assert_batch_consistencies(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5 F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
) )
...@@ -55,7 +57,7 @@ class TestFunctional(unittest.TestCase): ...@@ -55,7 +57,7 @@ class TestFunctional(unittest.TestCase):
for filename in filenames: for filename in filenames:
filepath = common_utils.get_asset_path(filename) filepath = common_utils.get_asset_path(filename)
waveform, sample_rate = torchaudio.load(filepath) waveform, sample_rate = torchaudio.load(filepath)
_test_batch(F.detect_pitch_frequency, waveform, sample_rate) self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)
def test_istft(self): def test_istft(self):
stft = torch.tensor([ stft = torch.tensor([
...@@ -63,39 +65,39 @@ class TestFunctional(unittest.TestCase): ...@@ -63,39 +65,39 @@ class TestFunctional(unittest.TestCase):
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
]) ])
_test_batch(F.istft, stft, n_fft=4, length=4) self.assert_batch_consistencies(F.istft, stft, n_fft=4, length=4)
def test_contrast(self): def test_contrast(self):
waveform = torch.rand(2, 100) - 0.5 waveform = torch.rand(2, 100) - 0.5
_test_batch(F.contrast, waveform, enhancement_amount=80.) self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)
def test_dcshift(self): def test_dcshift(self):
waveform = torch.rand(2, 100) - 0.5 waveform = torch.rand(2, 100) - 0.5
_test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05) self.assert_batch_consistencies(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
def test_overdrive(self): def test_overdrive(self):
waveform = torch.rand(2, 100) - 0.5 waveform = torch.rand(2, 100) - 0.5
_test_batch(F.overdrive, waveform, gain=45, colour=30) self.assert_batch_consistencies(F.overdrive, waveform, gain=45, colour=30)
def test_phaser(self): def test_phaser(self):
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath) waveform, sample_rate = torchaudio.load(filepath)
_test_batch(F.phaser, waveform, sample_rate) self.assert_batch_consistencies(F.phaser, waveform, sample_rate)
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
waveform = torch.randn(2, 1024) - 0.5 waveform = torch.randn(2, 1024) - 0.5
_test_batch(F.sliding_window_cmn, waveform, center=True, norm_vars=True) self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=True, norm_vars=True)
_test_batch(F.sliding_window_cmn, waveform, center=True, norm_vars=False) self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=True, norm_vars=False)
_test_batch(F.sliding_window_cmn, waveform, center=False, norm_vars=True) self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=False, norm_vars=True)
_test_batch(F.sliding_window_cmn, waveform, center=False, norm_vars=False) self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=False, norm_vars=False)
def test_vad(self): def test_vad(self):
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav") filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = torchaudio.load(filepath) waveform, sample_rate = torchaudio.load(filepath)
_test_batch(F.vad, waveform, sample_rate=sample_rate) self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)
class TestTransforms(unittest.TestCase): class TestTransforms(TestCase):
"""Test suite for classes defined in `transforms` module""" """Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self): def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
...@@ -106,7 +108,7 @@ class TestTransforms(unittest.TestCase): ...@@ -106,7 +108,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1)) computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_Resample(self): def test_batch_Resample(self):
waveform = torch.randn(2, 2786) waveform = torch.randn(2, 2786)
...@@ -117,7 +119,7 @@ class TestTransforms(unittest.TestCase): ...@@ -117,7 +119,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_MelScale(self): def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786) specgram = torch.randn(2, 31, 2786)
...@@ -129,7 +131,7 @@ class TestTransforms(unittest.TestCase): ...@@ -129,7 +131,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1)) computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_InverseMelScale(self): def test_batch_InverseMelScale(self):
n_mels = 32 n_mels = 32
...@@ -146,7 +148,7 @@ class TestTransforms(unittest.TestCase): ...@@ -146,7 +148,7 @@ class TestTransforms(unittest.TestCase):
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here. # exactly same result. For this reason, tolerance is very relaxed here.
torch.testing.assert_allclose(computed, expected, atol=1.0, rtol=1e-5) self.assertEqual(computed, expected, atol=1.0, rtol=1e-5)
def test_batch_compute_deltas(self): def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786) specgram = torch.randn(2, 31, 2786)
...@@ -158,7 +160,7 @@ class TestTransforms(unittest.TestCase): ...@@ -158,7 +160,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1)) computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_mulaw(self): def test_batch_mulaw(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
...@@ -173,7 +175,7 @@ class TestTransforms(unittest.TestCase): ...@@ -173,7 +175,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawEncoding()(waveform_batched) computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
# Single then transform then batch # Single then transform then batch
waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded) waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
...@@ -183,7 +185,7 @@ class TestTransforms(unittest.TestCase): ...@@ -183,7 +185,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawDecoding()(computed) computed = torchaudio.transforms.MuLawDecoding()(computed)
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_spectrogram(self): def test_batch_spectrogram(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
...@@ -194,7 +196,7 @@ class TestTransforms(unittest.TestCase): ...@@ -194,7 +196,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_melspectrogram(self): def test_batch_melspectrogram(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
...@@ -205,7 +207,7 @@ class TestTransforms(unittest.TestCase): ...@@ -205,7 +207,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_mfcc(self): def test_batch_mfcc(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
...@@ -216,7 +218,7 @@ class TestTransforms(unittest.TestCase): ...@@ -216,7 +218,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected, atol=1e-4, rtol=1e-5) self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
def test_batch_TimeStretch(self): def test_batch_TimeStretch(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
...@@ -250,7 +252,7 @@ class TestTransforms(unittest.TestCase): ...@@ -250,7 +252,7 @@ class TestTransforms(unittest.TestCase):
hop_length=512, hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1)) )(complex_specgrams.repeat(3, 1, 1, 1, 1))
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5) self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
def test_batch_Fade(self): def test_batch_Fade(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
...@@ -263,7 +265,7 @@ class TestTransforms(unittest.TestCase): ...@@ -263,7 +265,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
def test_batch_Vol(self): def test_batch_Vol(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
...@@ -274,7 +276,7 @@ class TestTransforms(unittest.TestCase): ...@@ -274,7 +276,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected) self.assertEqual(computed, expected)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment