Unverified Commit 5cf3e56f authored by moto, committed by GitHub

Extract batch tests from test_transforms and move them to a dedicated module (#501)

parent 0f8fa5f8
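The tests in the new module all follow the same batch-consistency check: apply a transform to a single example and tile the result along a new leading batch dimension, then tile the input first and apply the transform to the whole batch, and verify the two results agree in shape and value. A minimal sketch of that pattern as a reusable helper follows; the helper name `_assert_batch_consistency` is illustrative only and not part of this commit (the actual helpers, `_test_batch_shape` and `_test_batch`, have their bodies outside this diff's context).

import torch
import torchaudio


def _assert_batch_consistency(transform, tensor, batch_size=3, atol=1e-8):
    # Hypothetical helper, not from the PR: sketches the single-vs-batch pattern.
    # "Single then batch": transform one example, then tile the result along
    # a new leading batch dimension.
    single = transform(tensor)
    expected = single.repeat(batch_size, *([1] * single.dim()))
    # "Batch then transform": tile the input first, then transform the batch.
    batched = tensor.repeat(batch_size, *([1] * tensor.dim()))
    computed = transform(batched)
    # Both orderings must agree in shape and, up to tolerance, in value.
    assert computed.shape == expected.shape, (computed.shape, expected.shape)
    assert torch.allclose(computed, expected, atol=atol)


# Example: Spectrogram applied to a 2-channel random waveform.
_assert_batch_consistency(torchaudio.transforms.Spectrogram(), torch.randn(2, 1000))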
@@ -7,6 +7,7 @@ import torchaudio
import torchaudio.functional as F
import common_utils
from common_utils import AudioBackendScope, BACKENDS
def _test_batch_shape(functional, tensor, *args, **kwargs):
@@ -102,3 +103,212 @@ class TestFunctional(unittest.TestCase):
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])
_test_batch(F.istft, stft, n_fft=4, length=4)
class TestTransforms(unittest.TestCase):
"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
# Single then transform then batch
expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_Resample(self):
waveform = torch.randn(2, 2786)
# Single then transform then batch
expected = torchaudio.transforms.Resample()(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786)
# Single then transform then batch
expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, n_mels, 2786)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_InverseMelScale(self):
n_mels = 32
n_stft = 5
mel_spec = torch.randn(2, n_mels, 32) ** 2
# Single then transform then batch
expected = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))
# shape = (3, 2, n_stft, 32)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
# Because InverseMelScale runs SGD on randomly initialized values, the results do not match
# exactly. For this reason, the tolerance is very relaxed here.
assert torch.allclose(computed, expected, atol=1.0)
def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)
# Single then transform then batch
expected = torchaudio.transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 31, 2786)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_mulaw(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch
waveform_encoded = torchaudio.transforms.MuLawEncoding()(waveform)
expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 278756)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
# Single then transform then batch
waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.MuLawDecoding()(computed)
# shape = (3, 2, 278756)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_spectrogram(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch
expected = torchaudio.transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_melspectrogram(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch
expected = torchaudio.transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_batch_mfcc(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
waveform, _ = torchaudio.load(test_filepath)
# Single then transform then batch
expected = torchaudio.transforms.MFCC()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-5)
def test_batch_TimeStretch(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
kwargs = {
'n_fft': 2048,
'hop_length': 512,
'win_length': 2048,
'window': torch.hann_window(2048),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
rate = 2
complex_specgrams = torch.stft(waveform, **kwargs)
# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams).repeat(3, 1, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-5)
def test_batch_Fade(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
fade_in_len = 3000
fade_out_len = 3000
# Single then transform then batch
expected = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_Vol(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch
expected = torchaudio.transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
......@@ -44,18 +44,6 @@ class Tester(unittest.TestCase):
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
# Single then transform then batch
expected = transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)
# Batch then transform
computed = transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_AmplitudeToDB(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
@@ -175,18 +163,6 @@ class Tester(unittest.TestCase):
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
def test_batch_Resample(self):
waveform = torch.randn(2, 2786)
# Single then transform then batch
expected = transforms.Resample()(waveform).repeat(3, 1, 1)
# Batch then transform
computed = transforms.Resample()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_resample_size(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
waveform, sample_rate = torchaudio.load(input_path)
@@ -242,174 +218,6 @@ class Tester(unittest.TestCase):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786)
# Single then transform then batch
expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, n_mels, 2786)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_InverseMelScale(self):
n_fft = 8
n_mels = 32
n_stft = 5
mel_spec = torch.randn(2, n_mels, 32) ** 2
# Single then transform then batch
expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))
# shape = (3, 2, n_stft, 32)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
# Because InverseMelScale runs SGD on randomly initialized values, the results do not match
# exactly. For this reason, the tolerance is very relaxed here.
self.assertTrue(torch.allclose(computed, expected, atol=1.0))
def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)
# Single then transform then batch
expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 31, 2786)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_mulaw(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
# Single then transform then batch
waveform_encoded = transforms.MuLawEncoding()(waveform)
expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 278756)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
# Single then transform then batch
waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)
# Batch then transform
computed = transforms.MuLawDecoding()(computed)
# shape = (3, 2, 278756)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_spectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
# Single then transform then batch
expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_melspectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
# Single then transform then batch
expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_batch_mfcc(self):
test_filepath = os.path.join(
self.test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3'
)
waveform, sample_rate = torchaudio.load(test_filepath)
# Single then transform then batch
expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected, atol=1e-5))
def test_batch_TimeStretch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
kwargs = {
'n_fft': 2048,
'hop_length': 512,
'win_length': 2048,
'window': torch.hann_window(2048),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
rate = 2
complex_specgrams = torch.stft(waveform, **kwargs)
# Single then transform then batch
expected = transforms.TimeStretch(fixed_rate=rate,
n_freq=1025,
hop_length=512)(complex_specgrams).repeat(3, 1, 1, 1, 1)
# Batch then transform
computed = transforms.TimeStretch(fixed_rate=rate,
n_freq=1025,
hop_length=512)(complex_specgrams.repeat(3, 1, 1, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected, atol=1e-5))
def test_batch_Fade(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
fade_in_len = 3000
fade_out_len = 3000
# Single then transform then batch
expected = transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_Vol(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
# Single then transform then batch
expected = transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
if __name__ == '__main__':
unittest.main()