Commit fad19fab (unverified), authored by Caroline Chen, committed by GitHub
Browse files

Ensure resampling identity is unchanged (#1537)

parent f1a0b605
......@@ -56,15 +56,12 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def setUp(self):
super().setUp()
# 1. test signal for testing resampling
self.test1_signal_sr = 16000
self.test1_signal = common_utils.get_whitenoise(
sample_rate=self.test1_signal_sr, duration=0.5,
# test signal for testing resampling
self.test_signal_sr = 16000
self.test_signal = common_utils.get_whitenoise(
sample_rate=self.test_signal_sr, duration=0.5,
)
# 2. test audio file corresponding to saved kaldi ark files
self.test2_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
# separating test files by their types (e.g 'spec', 'fbank', etc.)
for f in os.listdir(kaldi_output_dir):
dash_idx = f.find('-')
......@@ -176,30 +173,23 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
# Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))
def test_resample_waveform(self):
    """Run ``kaldi.resample_waveform`` through ``_compliance_test_helper``."""
    def get_output_fn(sound, args):
        # The input is cast to float32 before resampling; args[1] and args[2]
        # are forwarded as the second and third positional arguments.
        waveform = sound.to(torch.float32)
        return kaldi.resample_waveform(waveform, args[1], args[2])

    self._compliance_test_helper(
        self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5,
    )
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2,
upsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr * 2,
resampling_method=resampling_method)
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)
self.assertTrue(upsample_sound.size(-1) == self.test_signal.size(-1) * 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2,
downsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr // 2,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)
self.assertTrue(downsample_sound.size(-1) == self.test_signal.size(-1) // 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr,
downsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))
self.assertTrue(downsample_sound.size(-1) == self.test_signal.size(-1))
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
......@@ -244,18 +234,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def test_resample_waveform_multi_channel(self, resampling_method):
num_channels = 3
multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)
multi_sound = self.test_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)
for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2,
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test_signal_sr, self.test_signal_sr // 2,
resampling_method=resampling_method)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2,
single_channel = self.test_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test_signal_sr,
self.test_signal_sr // 2,
resampling_method=resampling_method)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
......@@ -259,6 +259,16 @@ class Functional(TestBaseMixin):
self.assertEqual(specgrams, specgrams_copy)
@parameterized.expand(list(itertools.product(
    ["sinc_interpolation", "kaiser_window"],
    [16000, 44100],
)))
def test_resample_identity(self, resampling_method, sample_rate):
    """Resampling to the same sample rate must return the waveform unchanged.

    Args:
        resampling_method: Name of the resampling method passed to ``F.resample``.
        sample_rate: Rate used as both the original and the target frequency.
    """
    waveform = get_whitenoise(sample_rate=sample_rate, duration=1)
    # Forward the parameterized method; without it every generated case would
    # exercise only the default resampling method.
    resampled = F.resample(
        waveform, sample_rate, sample_rate, resampling_method=resampling_method,
    )
    self.assertEqual(waveform, resampled)
def test_resample_no_warning(self):
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)
......
import itertools
import warnings
import torch
......@@ -8,6 +9,7 @@ from torchaudio_unittest.common_utils import (
get_whitenoise,
get_spectrogram,
)
from parameterized import parameterized
def _get_ratio(mat):
......@@ -77,3 +79,14 @@ class TransformsTestBase(TestBaseMixin):
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
assert len(caught_warnings) == 0
@parameterized.expand(list(itertools.product(
    ["sinc_interpolation", "kaiser_window"],
    [16000, 44100],
)))
def test_resample_identity(self, resampling_method, sample_rate):
    """A ``T.Resample`` with equal input and output rates must return its input unchanged.

    Args:
        resampling_method: Name of the resampling method passed to ``T.Resample``.
        sample_rate: Rate used as both the original and the target frequency.
    """
    waveform = get_whitenoise(sample_rate=sample_rate, duration=1)
    # Forward the parameterized method; without it every generated case would
    # construct the transform with the default resampling method.
    resampler = T.Resample(sample_rate, sample_rate, resampling_method=resampling_method)
    resampled = resampler(waveform)
    self.assertEqual(waveform, resampled)
......@@ -1449,6 +1449,9 @@ def resample(
assert orig_freq > 0.0 and new_freq > 0.0
if orig_freq == new_freq:
return waveform
gcd = math.gcd(int(orig_freq), int(new_freq))
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
......
......@@ -696,10 +696,11 @@ class Resample(torch.nn.Module):
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff
kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
self.register_buffer('kernel', kernel)
if self.orig_freq != self.new_freq:
kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
self.register_buffer('kernel', kernel)
def forward(self, waveform: Tensor) -> Tensor:
r"""
......@@ -709,6 +710,8 @@ class Resample(torch.nn.Module):
Returns:
Tensor: Output signal of dimension (..., time).
"""
if self.orig_freq == self.new_freq:
return waveform
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment