Unverified Commit 445e14d1 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

batch resample transform (#435)

parent ffeee199
...@@ -326,6 +326,18 @@ class Tester(unittest.TestCase): ...@@ -326,6 +326,18 @@ class Tester(unittest.TestCase):
_test_script_module(transforms.Spectrogram, tensor, sample_rate, sample_rate_2) _test_script_module(transforms.Spectrogram, tensor, sample_rate, sample_rate_2)
def test_batch_Resample(self):
waveform = torch.randn(2, 2786)
# Single then transform then batch
expected = transforms.Resample()(waveform).repeat(3, 1, 1)
# Batch then transform
computed = transforms.Resample()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_scriptmodule_ComplexNorm(self): def test_scriptmodule_ComplexNorm(self):
tensor = torch.rand((1, 2, 201, 2)) tensor = torch.rand((1, 2, 201, 2))
_test_script_module(transforms.ComplexNorm, tensor) _test_script_module(transforms.ComplexNorm, tensor)
......
...@@ -445,7 +445,17 @@ class Resample(torch.nn.Module): ...@@ -445,7 +445,17 @@ class Resample(torch.nn.Module):
torch.Tensor: Output signal of dimension (..., time) torch.Tensor: Output signal of dimension (..., time)
""" """
if self.resampling_method == 'sinc_interpolation': if self.resampling_method == 'sinc_interpolation':
return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
# unpack batch
waveform = waveform.view(shape[:-1] + waveform.shape[-1:])
return waveform
raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment