Unverified Commit 15a7f78c authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Migrate resample tests from kaldi to functional (#1520)

parent 68823423
...@@ -7,7 +7,6 @@ import torchaudio.compliance.kaldi as kaldi ...@@ -7,7 +7,6 @@ import torchaudio.compliance.kaldi as kaldi
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .compliance import utils as compliance_utils from .compliance import utils as compliance_utils
from parameterized import parameterized
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
...@@ -53,15 +52,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -53,15 +52,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
test_filepath = common_utils.get_asset_path('kaldi_file.wav') test_filepath = common_utils.get_asset_path('kaldi_file.wav')
test_filepaths = {prefix: [] for prefix in compliance_utils.TEST_PREFIX} test_filepaths = {prefix: [] for prefix in compliance_utils.TEST_PREFIX}
def setUp(self):
super().setUp()
# test signal for testing resampling
self.test_signal_sr = 16000
self.test_signal = common_utils.get_whitenoise(
sample_rate=self.test_signal_sr, duration=0.5,
)
# separating test files by their types (e.g 'spec', 'fbank', etc.) # separating test files by their types (e.g 'spec', 'fbank', etc.)
for f in os.listdir(kaldi_output_dir): for f in os.listdir(kaldi_output_dir):
dash_idx = f.find('-') dash_idx = f.find('-')
...@@ -172,80 +162,3 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -172,80 +162,3 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def test_mfcc_empty(self): def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error # Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0)) self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method):
upsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr * 2,
resampling_method=resampling_method)
self.assertTrue(upsample_sound.size(-1) == self.test_signal.size(-1) * 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr // 2,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test_signal.size(-1) // 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test_signal.size(-1))
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
# resample the signal and compare it to the ground truth
n_to_trim = 20
sample_rate = 1000
new_sample_rate = sample_rate
if up_scale_factor is not None:
new_sample_rate *= up_scale_factor
if down_scale_factor is not None:
new_sample_rate //= down_scale_factor
duration = 5 # seconds
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate,
resampling_method=resampling_method).squeeze()
new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
# trim the first/last n samples as these points have boundary effects
ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim]
self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_multi_channel(self, resampling_method):
num_channels = 3
multi_sound = self.test_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)
for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test_signal_sr, self.test_signal_sr // 2,
resampling_method=resampling_method)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = self.test_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test_signal_sr,
self.test_signal_sr // 2,
resampling_method=resampling_method)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
...@@ -197,6 +197,17 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -197,6 +197,17 @@ class TestFunctional(common_utils.TorchaudioTestCase):
F.sliding_window_cmn, spectrogram, center=center, F.sliding_window_cmn, spectrogram, center=center,
norm_vars=norm_vars) norm_vars=norm_vars)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform(self, resampling_method):
num_channels = 3
sr = 16000
new_sr = sr // 2
multi_sound = common_utils.get_whitenoise(sample_rate=sr, n_channels=num_channels, duration=0.5,)
self.assert_batch_consistency(
F.resample, multi_sound, orig_freq=sr, new_freq=new_sr,
resampling_method=resampling_method, rtol=1e-4, atol=1e-7)
@common_utils.skipIfNoKaldi @common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self): def test_compute_kaldi_pitch(self):
sample_rate = 44100 sample_rate = 44100
......
...@@ -13,6 +13,35 @@ from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested ...@@ -13,6 +13,35 @@ from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested
class Functional(TestBaseMixin): class Functional(TestBaseMixin):
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
# resample the signal and compare it to the ground truth
n_to_trim = 20
sample_rate = 1000
new_sample_rate = sample_rate
if up_scale_factor is not None:
new_sample_rate *= up_scale_factor
if down_scale_factor is not None:
new_sample_rate //= down_scale_factor
duration = 5 # seconds
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = F.resample(sound, sample_rate, new_sample_rate,
resampling_method=resampling_method).squeeze()
new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
# trim the first/last n samples as these points have boundary effects
ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim]
self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)
def test_lfilter_simple(self): def test_lfilter_simple(self):
""" """
Create a very basic signal, Create a very basic signal,
...@@ -269,6 +298,41 @@ class Functional(TestBaseMixin): ...@@ -269,6 +298,41 @@ class Functional(TestBaseMixin):
resampled = F.resample(waveform, sample_rate, sample_rate) resampled = F.resample(waveform, sample_rate, sample_rate)
self.assertEqual(waveform, resampled) self.assertEqual(waveform, resampled)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method):
sr = 16000
waveform = get_whitenoise(sample_rate=sr, duration=0.5,)
upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method)
assert upsampled.size(-1) == waveform.size(-1) * 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method):
sr = 16000
waveform = get_whitenoise(sample_rate=sr, duration=0.5,)
downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
assert downsampled.size(-1) == waveform.size(-1) // 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method):
sr = 16000
waveform = get_whitenoise(sample_rate=sr, duration=0.5,)
resampled = F.resample(waveform, sr, sr, resampling_method=resampling_method)
assert resampled.size(-1) == waveform.size(-1)
@parameterized.expand(list(itertools.product(
["sinc_interpolation", "kaiser_window"],
list(range(1, 20)),
)))
def test_resample_waveform_downsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)
@parameterized.expand(list(itertools.product(
["sinc_interpolation", "kaiser_window"],
list(range(1, 20)),
)))
def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
def test_resample_no_warning(self): def test_resample_no_warning(self):
sample_rate = 44100 sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1) waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment