Unverified Commit b40aee5a authored by nateanl's avatar nateanl Committed by GitHub
Browse files

Refactor batch consistency test in transforms (#1772)

parent 0f822179
"""Test numerical consistency among single input and batched input.""" """Test numerical consistency among single input and batched input."""
import torch import torch
import torchaudio
from parameterized import parameterized from parameterized import parameterized
from torchaudio import transforms as T
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
class TestTransforms(common_utils.TorchaudioTestCase): class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for classes defined in `transforms` module"""
backend = 'default' backend = 'default'
"""Test suite for classes defined in `transforms` module""" def assert_batch_consistency(
def test_batch_AmplitudeToDB(self): self, transform, batch, *args, atol=1e-8, rtol=1e-5, seed=42,
spec = torch.rand((2, 6, 201)) **kwargs):
n = batch.size(0)
# Single then transform then batch # Compute items separately, then batch the result
expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1) torch.random.manual_seed(seed)
items_input = batch.clone()
items_result = torch.stack([
transform(items_input[i], *args, **kwargs) for i in range(n)
])
# Batch then transform # Batch the input and run
computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1)) torch.random.manual_seed(seed)
batch_input = batch.clone()
batch_result = transform(batch_input, *args, **kwargs)
self.assertEqual(computed, expected) self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol)
self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol)
def test_batch_Resample(self): def test_batch_AmplitudeToDB(self):
waveform = torch.randn(2, 2786) spec = torch.rand((3, 2, 6, 201))
transform = T.AmplitudeToDB()
# Single then transform then batch self.assert_batch_consistency(transform, spec)
expected = torchaudio.transforms.Resample()(waveform).repeat(3, 1, 1)
# Batch then transform def test_batch_Resample(self):
computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1)) waveform = torch.randn(3, 2, 2786)
transform = T.Resample()
self.assertEqual(computed, expected) self.assert_batch_consistency(transform, waveform)
def test_batch_MelScale(self): def test_batch_MelScale(self):
specgram = torch.randn(2, 201, 256) specgram = torch.randn(3, 2, 201, 256)
transform = T.MelScale()
# Single then transform then batch
expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 128, 256) self.assert_batch_consistency(transform, specgram)
self.assertEqual(computed, expected)
def test_batch_InverseMelScale(self): def test_batch_InverseMelScale(self):
n_mels = 32 n_mels = 32
n_stft = 5 n_stft = 5
mel_spec = torch.randn(2, n_mels, 32) ** 2 mel_spec = torch.randn(3, 2, n_mels, 32) ** 2
transform = T.InverseMelScale(n_stft, n_mels)
# Single then transform then batch
expected = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))
# shape = (3, 2, n_mels, 32)
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here. # exactly same result. For this reason, tolerance is very relaxed here.
self.assertEqual(computed, expected, atol=1.0, rtol=1e-5) self.assert_batch_consistency(transform, mel_spec, atol=1.0, rtol=1e-5)
def test_batch_compute_deltas(self): def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786) specgram = torch.randn(3, 2, 31, 2786)
transform = T.ComputeDeltas()
# Single then transform then batch self.assert_batch_consistency(transform, specgram)
expected = torchaudio.transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394)
self.assertEqual(computed, expected)
def test_batch_mulaw(self): def test_batch_mulaw(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
# Single then transform then batch # Single then transform then batch
waveform_encoded = torchaudio.transforms.MuLawEncoding()(waveform) expected = [T.MuLawEncoding()(waveform[i]) for i in range(3)]
expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1) expected = torch.stack(expected)
# Batch then transform # Batch then transform
waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1) computed = T.MuLawEncoding()(waveform)
computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
# Single then transform then batch # Single then transform then batch
waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded) expected_decoded = [T.MuLawDecoding()(expected[i]) for i in range(3)]
expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1) expected_decoded = torch.stack(expected_decoded)
# Batch then transform # Batch then transform
computed = torchaudio.transforms.MuLawDecoding()(computed) computed_decoded = T.MuLawDecoding()(computed)
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
self.assertEqual(computed, expected) self.assertEqual(computed_decoded, expected_decoded)
def test_batch_spectrogram(self): def test_batch_spectrogram(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
# Single then transform then batch transform = T.Spectrogram()
expected = torchaudio.transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform self.assert_batch_consistency(transform, waveform)
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
def test_batch_inverse_spectrogram(self): def test_batch_inverse_spectrogram(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
transform = torchaudio.transforms.Spectrogram(power=None)(waveform) specgram = common_utils.get_spectrogram(waveform, n_fft=400)
specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1])
# Single then transform then batch transform = T.InverseSpectrogram(n_fft=400)
expected = torchaudio.transforms.InverseSpectrogram()(transform).repeat(3, 1, 1)
# Batch then transform self.assert_batch_consistency(transform, specgram)
computed = torchaudio.transforms.InverseSpectrogram()(transform.repeat(3, 1, 1, 1))
self.assertEqual(computed, expected)
def test_batch_melspectrogram(self): def test_batch_melspectrogram(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
transform = T.MelSpectrogram()
# Single then transform then batch self.assert_batch_consistency(transform, waveform)
expected = torchaudio.transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
def test_batch_mfcc(self): def test_batch_mfcc(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
transform = T.MFCC()
# Single then transform then batch self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)
expected = torchaudio.transforms.MFCC()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
def test_batch_lfcc(self): def test_batch_lfcc(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
# Single then transform then batch transform = T.LFCC()
expected = torchaudio.transforms.LFCC()(waveform).repeat(3, 1, 1, 1)
# Batch then transform self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)
computed = torchaudio.transforms.LFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
@parameterized.expand([(True, ), (False, )]) @parameterized.expand([(True, ), (False, )])
def test_batch_TimeStretch(self, test_pseudo_complex): def test_batch_TimeStretch(self, test_pseudo_complex):
rate = 2 rate = 2
num_freq = 1025 num_freq = 1025
num_frames = 400 num_frames = 400
batch = 3
spec = torch.randn(num_freq, num_frames, dtype=torch.complex64) spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64)
pattern = [3, 1, 1, 1]
if test_pseudo_complex: if test_pseudo_complex:
spec = torch.view_as_real(spec) spec = torch.view_as_real(spec)
pattern += [1]
# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=num_freq,
hop_length=512,
)(spec).repeat(*pattern)
# Batch then transform transform = T.TimeStretch(
computed = torchaudio.transforms.TimeStretch(
fixed_rate=rate, fixed_rate=rate,
n_freq=num_freq, n_freq=num_freq,
hop_length=512, hop_length=512
)(spec.repeat(*pattern)) )
self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5) self.assert_batch_consistency(transform, spec, atol=1e-5, rtol=1e-5)
def test_batch_Fade(self): def test_batch_Fade(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
fade_in_len = 3000 fade_in_len = 3000
fade_out_len = 3000 fade_out_len = 3000
transform = T.Fade(fade_in_len, fade_out_len)
# Single then transform then batch self.assert_batch_consistency(transform, waveform)
expected = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
def test_batch_Vol(self): def test_batch_Vol(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
transform = T.Vol(gain=1.1)
# Single then transform then batch self.assert_batch_consistency(transform, waveform)
expected = torchaudio.transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
def test_batch_spectral_centroid(self): def test_batch_spectral_centroid(self):
sample_rate = 44100 sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate) waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
transform = T.SpectralCentroid(sample_rate)
# Single then transform then batch self.assert_batch_consistency(transform, waveform)
expected = torchaudio.transforms.SpectralCentroid(sample_rate)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.SpectralCentroid(sample_rate)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
def test_batch_pitch_shift(self): def test_batch_pitch_shift(self):
sample_rate = 8000 sample_rate = 8000
n_steps = -2 n_steps = -2
waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05) waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=6)
waveform = waveform.reshape(3, 2, -1)
transform = T.PitchShift(sample_rate, n_steps, n_fft=400)
# Single then transform then batch self.assert_batch_consistency(transform, waveform)
expected = torchaudio.transforms.PitchShift(sample_rate, n_steps, n_fft=400)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.PitchShift(sample_rate, n_steps, n_fft=400)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment