Commit dc452aab, authored by jamarshon and committed by cpuhrsch
Browse files

Multichannel Resample (#154)

parent 629db65d
...@@ -267,6 +267,22 @@ class Test_Kaldi(unittest.TestCase): ...@@ -267,6 +267,22 @@ class Test_Kaldi(unittest.TestCase):
for i in range(1, 20): for i in range(1, 20):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0) self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0)
def test_resample_waveform_multi_channel(self):
    """Resampling a (c, n) tensor must match resampling each channel alone.

    A single-channel fixture is tiled into several channels, each scaled by
    a different gain so the channels are distinguishable. The stacked tensor
    is downsampled to half the original rate in one call, and every row of
    that result is compared against the same channel resampled by itself.
    """
    channel_count = 3
    # Fixture is expected to be mono, i.e. shape (1, 8000).
    mono, rate = torchaudio.load_wav(self.test_8000_filepath)
    target_rate = rate // 2

    # Build the multi-channel input: shape (channel_count, 8000), with
    # channel k scaled by (k + 1) * 1.5.
    stacked = mono.repeat(channel_count, 1)
    for ch in range(channel_count):
        stacked[ch, :] *= (ch + 1) * 1.5

    stacked_resampled = kaldi.resample_waveform(stacked, rate, target_rate)

    # Each row of the batched result should agree with the channel
    # resampled separately.
    for ch in range(channel_count):
        lone_channel = mono * (ch + 1) * 1.5
        lone_resampled = kaldi.resample_waveform(lone_channel, rate, target_rate)
        rows_match = torch.allclose(stacked_resampled[ch, :], lone_resampled, rtol=1e-4)
        self.assertTrue(rows_match)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -711,7 +711,8 @@ def resample_waveform(wave, orig_freq, new_freq, lowpass_filter_width=6): ...@@ -711,7 +711,8 @@ def resample_waveform(wave, orig_freq, new_freq, lowpass_filter_width=6):
wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding)) wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding))
conv_wave = torch.nn.functional.conv1d( conv_wave = torch.nn.functional.conv1d(
wave_to_conv.unsqueeze(0), weights[i].view(1, 1, window_size), stride=conv_stride) wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1),
stride=conv_stride, groups=num_channels)
# we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride] # we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride]
dilated_conv_wave = torch.nn.functional.conv_transpose1d( dilated_conv_wave = torch.nn.functional.conv_transpose1d(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment