Unverified Commit d3f967e9 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

add batch test to TimeStretch (#459)

parent 9efc3503
...@@ -519,6 +519,36 @@ class Tester(unittest.TestCase): ...@@ -519,6 +519,36 @@ class Tester(unittest.TestCase):
tensor = torch.rand((10, 2, n_freq, 10, 2)) tensor = torch.rand((10, 2, n_freq, 10, 2))
_test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate) _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)
def test_batch_TimeStretch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
kwargs = {
'n_fft': 2048,
'hop_length': 512,
'win_length': 2048,
'window': torch.hann_window(2048),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
rate = 2
complex_specgrams = torch.stft(waveform, **kwargs)
# Single then transform then batch
expected = transforms.TimeStretch(fixed_rate=rate,
n_freq=1025,
hop_length=512)(complex_specgrams).repeat(3, 1, 1, 1, 1)
# Batch then transform
computed = transforms.TimeStretch(fixed_rate=rate,
n_freq=1025,
hop_length=512)(complex_specgrams.repeat(3, 1, 1, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected, atol=1e-5))
def test_batch_Fade(self): def test_batch_Fade(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) waveform, sample_rate = torchaudio.load(self.test_filepath)
fade_in_len = 3000 fade_in_len = 3000
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment