Unverified commit b40aee5a, authored by nateanl and committed by GitHub
Browse files

Refactor batch consistency test in transforms (#1772)

parent 0f822179
"""Test numerical consistency among single input and batched input."""
import torch
import torchaudio
from parameterized import parameterized
from torchaudio import transforms as T
from torchaudio_unittest import common_utils
class TestTransforms(common_utils.TorchaudioTestCase):
    """Test suite for classes defined in `transforms` module.

    Every test checks that applying a transform to a batched tensor gives the
    same result as applying it to each batch item separately and stacking the
    individual outputs.
    """
    backend = 'default'

    def assert_batch_consistency(
            self, transform, batch, *args, atol=1e-8, rtol=1e-5, seed=42,
            **kwargs):
        """Assert that ``transform`` treats the leading dimension as a batch.

        The transform is run once per item along dimension 0 of ``batch`` and
        once on the whole tensor; both the outputs and the (possibly mutated)
        inputs of the two runs must match within the given tolerances.

        Args:
            transform: Callable applied as ``transform(tensor, *args, **kwargs)``.
            batch: Tensor whose first dimension indexes the batch items.
            *args: Extra positional arguments forwarded to ``transform``.
            atol: Absolute tolerance passed to ``assertEqual``.
            rtol: Relative tolerance passed to ``assertEqual``.
            seed: Seed set before each of the two runs so that transforms
                drawing random numbers start from the same generator state.
            **kwargs: Extra keyword arguments forwarded to ``transform``.
        """
        # Compute items separately, then batch the result.
        torch.random.manual_seed(seed)
        items_input = batch.clone()
        items_result = torch.stack([
            transform(item, *args, **kwargs) for item in items_input
        ])

        # Batch the input and run.
        torch.random.manual_seed(seed)
        batch_input = batch.clone()
        batch_result = transform(batch_input, *args, **kwargs)

        # Comparing the cloned inputs catches a transform that modifies its
        # argument in place differently in the two modes.
        self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol)
        self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol)

    def test_batch_AmplitudeToDB(self):
        """AmplitudeToDB is consistent between per-item and batched input."""
        spec = torch.rand((3, 2, 6, 201))
        transform = T.AmplitudeToDB()

        self.assert_batch_consistency(transform, spec)

    def test_batch_Resample(self):
        """Resample is consistent between per-item and batched input."""
        waveform = torch.randn(3, 2, 2786)
        transform = T.Resample()

        self.assert_batch_consistency(transform, waveform)

    def test_batch_MelScale(self):
        """MelScale is consistent between per-item and batched input."""
        specgram = torch.randn(3, 2, 201, 256)
        transform = T.MelScale()

        self.assert_batch_consistency(transform, specgram)

    def test_batch_InverseMelScale(self):
        """InverseMelScale is consistent within a relaxed tolerance."""
        n_mels = 32
        n_stft = 5
        # Squared so that every entry of the mel spectrogram is non-negative.
        mel_spec = torch.randn(3, 2, n_mels, 32) ** 2
        transform = T.InverseMelScale(n_stft, n_mels)

        # InverseMelScale runs SGD on randomly initialized values, so the two
        # runs do not yield exactly the same result. For this reason, the
        # tolerance is very relaxed here.
        self.assert_batch_consistency(transform, mel_spec, atol=1.0, rtol=1e-5)

    def test_batch_compute_deltas(self):
        """ComputeDeltas is consistent between per-item and batched input."""
        specgram = torch.randn(3, 2, 31, 2786)
        transform = T.ComputeDeltas()

        self.assert_batch_consistency(transform, specgram)

    def test_batch_mulaw(self):
        """MuLawEncoding and MuLawDecoding are consistent on batched input.

        The decoding check is chained onto the encoding output, which is why
        this test does not go through ``assert_batch_consistency``.
        """
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        # Regroup the 6 channels into a leading batch dimension of 3.
        waveform = waveform.reshape(3, 2, -1)

        # Encode each item separately, then batch the results.
        expected = torch.stack([T.MuLawEncoding()(item) for item in waveform])

        # Encode the whole batch at once.
        computed = T.MuLawEncoding()(waveform)

        self.assertEqual(computed, expected)

        # Decode each encoded item separately, then batch the results.
        expected_decoded = torch.stack([T.MuLawDecoding()(item) for item in expected])

        # Decode the whole encoded batch at once.
        computed_decoded = T.MuLawDecoding()(computed)

        self.assertEqual(computed_decoded, expected_decoded)

    def test_batch_spectrogram(self):
        """Spectrogram is consistent between per-item and batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        transform = T.Spectrogram()

        self.assert_batch_consistency(transform, waveform)

    def test_batch_inverse_spectrogram(self):
        """InverseSpectrogram is consistent between per-item and batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
        # Regroup the 6 channels into a leading batch dimension of 3 while
        # keeping the last two dimensions of the spectrogram untouched.
        specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1])
        transform = T.InverseSpectrogram(n_fft=400)

        self.assert_batch_consistency(transform, specgram)

    def test_batch_melspectrogram(self):
        """MelSpectrogram is consistent between per-item and batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        transform = T.MelSpectrogram()

        self.assert_batch_consistency(transform, waveform)

    def test_batch_mfcc(self):
        """MFCC is consistent between per-item and batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        transform = T.MFCC()

        self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)

    def test_batch_lfcc(self):
        """LFCC is consistent between per-item and batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        transform = T.LFCC()

        self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)

    @parameterized.expand([(True, ), (False, )])
    def test_batch_TimeStretch(self, test_pseudo_complex):
        """TimeStretch is consistent between per-item and batched input.

        Args:
            test_pseudo_complex: When true, the complex spectrogram is passed
                as a real tensor obtained with ``torch.view_as_real``.
        """
        rate = 2
        num_freq = 1025
        num_frames = 400
        batch = 3

        spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64)
        if test_pseudo_complex:
            spec = torch.view_as_real(spec)

        transform = T.TimeStretch(
            fixed_rate=rate,
            n_freq=num_freq,
            hop_length=512
        )

        self.assert_batch_consistency(transform, spec, atol=1e-5, rtol=1e-5)

    def test_batch_Fade(self):
        """Fade is consistent between per-item and batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        fade_in_len = 3000
        fade_out_len = 3000
        transform = T.Fade(fade_in_len, fade_out_len)

        self.assert_batch_consistency(transform, waveform)

    def test_batch_Vol(self):
        """Vol is consistent between per-item and batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        transform = T.Vol(gain=1.1)

        self.assert_batch_consistency(transform, waveform)

    def test_batch_spectral_centroid(self):
        """SpectralCentroid is consistent between per-item and batched input."""
        sample_rate = 44100
        waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        transform = T.SpectralCentroid(sample_rate)

        self.assert_batch_consistency(transform, waveform)

    def test_batch_pitch_shift(self):
        """PitchShift is consistent between per-item and batched input."""
        sample_rate = 8000
        n_steps = -2
        waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=6)
        waveform = waveform.reshape(3, 2, -1)
        transform = T.PitchShift(sample_rate, n_steps, n_fft=400)

        self.assert_batch_consistency(transform, waveform)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment