Unverified commit b8187b4c, authored by moto and committed by GitHub
Browse files

Adopt PyTorch's test util to batch consistency test (#643)

parent 44af0dea
"""Test numerical consistency among single input and batched input."""
import unittest
import platform
import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
import common_utils
def _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
class TestFunctional(TestCase):
"""Test functions defined in `functional` module"""
def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
# run then batch the result
torch.random.manual_seed(seed)
expected = functional(tensor.clone(), *args, **kwargs)
......@@ -20,16 +23,15 @@ def _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=1e-8,
pattern = [batch_size] + [1] * tensor.dim()
computed = functional(tensor.repeat(pattern), *args, **kwargs)
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
self.assertEqual(computed, expected, rtol=rtol, atol=atol)
def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
    """Run the batch-consistency check with batch sizes 1 and 3.

    Every positional and keyword argument is forwarded unchanged to
    ``_test_batch_consistency``; only ``batch_size`` differs between the
    two invocations.
    """
    for size in (1, 3):
        _test_batch_consistency(
            functional, tensor, *args,
            batch_size=size, atol=atol, rtol=rtol, seed=seed, **kwargs)
def assert_batch_consistencies(
        self, functional, tensor, *args, atol=1e-8, rtol=1e-5, seed=42, batch_sizes=(1, 3), **kwargs):
    """Run ``assert_batch_consistency`` once per batch size.

    Args:
        functional: Callable under test; passed through unchanged.
        tensor: Input handed to ``functional``; passed through unchanged.
        *args: Extra positional arguments forwarded to ``functional``.
        atol: Absolute tolerance forwarded to ``assert_batch_consistency``.
        rtol: Relative tolerance forwarded to ``assert_batch_consistency``.
        seed: Random seed forwarded to ``assert_batch_consistency``.
        batch_sizes: Batch sizes to check, in order. Defaults to ``(1, 3)``,
            the two sizes that were previously hard-coded.
        **kwargs: Extra keyword arguments forwarded to ``functional``.
    """
    for batch_size in batch_sizes:
        self.assert_batch_consistency(
            functional, tensor, *args, batch_size=batch_size, atol=atol, rtol=rtol, seed=seed, **kwargs)
class TestFunctional(unittest.TestCase):
"""Test functions defined in `functional` module"""
def test_griffinlim(self):
n_fft = 400
ws = 400
......@@ -41,7 +43,7 @@ class TestFunctional(unittest.TestCase):
n_iter = 32
length = 1000
tensor = torch.rand((1, 201, 6))
_test_batch(
self.assert_batch_consistencies(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
)
......@@ -55,7 +57,7 @@ class TestFunctional(unittest.TestCase):
for filename in filenames:
filepath = common_utils.get_asset_path(filename)
waveform, sample_rate = torchaudio.load(filepath)
_test_batch(F.detect_pitch_frequency, waveform, sample_rate)
self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)
def test_istft(self):
stft = torch.tensor([
......@@ -63,39 +65,39 @@ class TestFunctional(unittest.TestCase):
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])
_test_batch(F.istft, stft, n_fft=4, length=4)
self.assert_batch_consistencies(F.istft, stft, n_fft=4, length=4)
def test_contrast(self):
    """Check ``F.contrast`` for batch consistency."""
    # Random input of shape (2, 100), shifted so values lie in [-0.5, 0.5).
    waveform = torch.rand(2, 100) - 0.5
    # Only the TestCase-based helper is used; the legacy module-level
    # `_test_batch` call duplicated the same check.
    self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)
def test_dcshift(self):
    """Check ``F.dcshift`` for batch consistency."""
    # Random input of shape (2, 100), shifted so values lie in [-0.5, 0.5).
    waveform = torch.rand(2, 100) - 0.5
    # Only the TestCase-based helper is used; the legacy module-level
    # `_test_batch` call duplicated the same check.
    self.assert_batch_consistencies(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
def test_overdrive(self):
    """Check ``F.overdrive`` for batch consistency."""
    # Random input of shape (2, 100), shifted so values lie in [-0.5, 0.5).
    waveform = torch.rand(2, 100) - 0.5
    # Only the TestCase-based helper is used; the legacy module-level
    # `_test_batch` call duplicated the same check.
    self.assert_batch_consistencies(F.overdrive, waveform, gain=45, colour=30)
def test_phaser(self):
    """Check ``F.phaser`` for batch consistency on the ``whitenoise.wav`` asset."""
    filepath = common_utils.get_asset_path("whitenoise.wav")
    waveform, sample_rate = torchaudio.load(filepath)
    # Only the TestCase-based helper is used; the legacy module-level
    # `_test_batch` call duplicated the same check.
    self.assert_batch_consistencies(F.phaser, waveform, sample_rate)
def test_sliding_window_cmn(self):
    """Check ``F.sliding_window_cmn`` for batch consistency.

    Every combination of the ``center`` and ``norm_vars`` flags is
    exercised against the same random input.
    """
    # Normally distributed input of shape (2, 1024), shifted by -0.5.
    waveform = torch.randn(2, 1024) - 0.5
    # Same order as the original explicit calls:
    # (True, True), (True, False), (False, True), (False, False).
    for center in (True, False):
        for norm_vars in (True, False):
            # subTest reports which flag combination failed and lets the
            # remaining combinations still run.
            with self.subTest(center=center, norm_vars=norm_vars):
                self.assert_batch_consistencies(
                    F.sliding_window_cmn, waveform, center=center, norm_vars=norm_vars)
def test_vad(self):
    """Check ``F.vad`` for batch consistency on the ``vad-go-mono-32000.wav`` asset."""
    filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
    waveform, sample_rate = torchaudio.load(filepath)
    # Only the TestCase-based helper is used; the legacy module-level
    # `_test_batch` call duplicated the same check.
    self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)
class TestTransforms(unittest.TestCase):
class TestTransforms(TestCase):
"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
......@@ -106,7 +108,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_Resample(self):
waveform = torch.randn(2, 2786)
......@@ -117,7 +119,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786)
......@@ -129,7 +131,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_InverseMelScale(self):
n_mels = 32
......@@ -146,7 +148,7 @@ class TestTransforms(unittest.TestCase):
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
torch.testing.assert_allclose(computed, expected, atol=1.0, rtol=1e-5)
self.assertEqual(computed, expected, atol=1.0, rtol=1e-5)
def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)
......@@ -158,7 +160,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_mulaw(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
......@@ -173,7 +175,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
# Single then transform then batch
waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
......@@ -183,7 +185,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawDecoding()(computed)
# shape = (3, 2, 201, 1394)
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_spectrogram(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
......@@ -194,7 +196,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_melspectrogram(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
......@@ -205,7 +207,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_mfcc(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
......@@ -216,7 +218,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected, atol=1e-4, rtol=1e-5)
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
def test_batch_TimeStretch(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
......@@ -250,7 +252,7 @@ class TestTransforms(unittest.TestCase):
hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1))
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
def test_batch_Fade(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
......@@ -263,7 +265,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
def test_batch_Vol(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
......@@ -274,7 +276,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
self.assertEqual(computed, expected)
if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment