Unverified commit 7078fcd3, authored by Caroline Chen and committed by GitHub
Browse files

Add kaiser window support to resampling (#1509)

parent b8b732af
...@@ -7,6 +7,7 @@ import torchaudio.compliance.kaldi as kaldi ...@@ -7,6 +7,7 @@ import torchaudio.compliance.kaldi as kaldi
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .compliance import utils as compliance_utils from .compliance import utils as compliance_utils
from parameterized import parameterized
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
...@@ -182,20 +183,26 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -182,20 +183,26 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5) self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)
def test_resample_waveform_upsample_size(self): @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2) def test_resample_waveform_upsample_size(self, resampling_method):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2,
resampling_method=resampling_method)
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2) self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)
def test_resample_waveform_downsample_size(self): @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2) def test_resample_waveform_downsample_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2) self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)
def test_resample_waveform_identity_size(self): @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr) def test_resample_waveform_identity_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1)) self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None, def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
atol=1e-1, rtol=1e-4): resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
# resample the signal and compare it to the ground truth # resample the signal and compare it to the ground truth
n_to_trim = 20 n_to_trim = 20
sample_rate = 1000 sample_rate = 1000
...@@ -211,7 +218,8 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -211,7 +218,8 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0) sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze() estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate,
resampling_method=resampling_method).squeeze()
new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)] new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps) ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
...@@ -222,15 +230,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -222,15 +230,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol) self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)
def test_resample_waveform_downsample_accuracy(self): @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_accuracy(self, resampling_method):
for i in range(1, 20): for i in range(1, 20):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2) self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)
def test_resample_waveform_upsample_accuracy(self): @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_accuracy(self, resampling_method):
for i in range(1, 20): for i in range(1, 20):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0) self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
def test_resample_waveform_multi_channel(self): @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_multi_channel(self, resampling_method):
num_channels = 3 num_channels = 3
multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp) multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)
...@@ -238,11 +249,13 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -238,11 +249,13 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
for i in range(num_channels): for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5 multi_sound[i, :] *= (i + 1) * 1.5
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2) multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)
# check that sampling is same whether using separately or in a tensor of size (c, n) # check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels): for i in range(num_channels):
single_channel = self.test1_signal * (i + 1) * 1.5 single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr, single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2) self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7) self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
...@@ -169,9 +169,11 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -169,9 +169,11 @@ class Tester(common_utils.TorchaudioTestCase):
upsample_rate = sample_rate * 2 upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2 downsample_rate = sample_rate // 2
invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo') invalid_resampling_method = 'foo'
self.assertRaises(ValueError, invalid_resample, waveform) with self.assertRaises(ValueError):
torchaudio.transforms.Resample(sample_rate, upsample_rate,
resampling_method=invalid_resampling_method)
upsample_resample = torchaudio.transforms.Resample( upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method='sinc_interpolation') sample_rate, upsample_rate, resampling_method='sinc_interpolation')
......
...@@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor, ...@@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
orig_freq: float, orig_freq: float,
new_freq: float, new_freq: float,
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> Tensor: rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation") -> Tensor:
r"""Resamples the waveform at the new frequency. r"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``. This is a wrapper around ``torchaudio.functional.resample``.
...@@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor, ...@@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
Returns: Returns:
Tensor: The waveform at the new frequency Tensor: The waveform at the new frequency
""" """
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff) return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width,
rolloff, resampling_method)
...@@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel( ...@@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
new_freq: float, new_freq: float,
gcd: int, gcd: int,
lowpass_filter_width: int, lowpass_filter_width: int,
rolloff: float): rolloff: float,
resampling_method: str,
beta: Optional[float]):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn( warnings.warn(
...@@ -1318,9 +1320,15 @@ def _get_sinc_resample_kernel( ...@@ -1318,9 +1320,15 @@ def _get_sinc_resample_kernel(
"https://github.com/pytorch/audio/issues/1487." "https://github.com/pytorch/audio/issues/1487."
) )
if resampling_method not in ['sinc_interpolation', 'kaiser_window']:
raise ValueError('Invalid resampling method: {}'.format(resampling_method))
orig_freq = int(orig_freq) // gcd orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd new_freq = int(new_freq) // gcd
if resampling_method == "kaiser_window" and beta is None:
beta = 14.769656459379492
assert lowpass_filter_width > 0 assert lowpass_filter_width > 0
kernels = [] kernels = []
base_freq = min(orig_freq, new_freq) base_freq = min(orig_freq, new_freq)
...@@ -1352,15 +1360,20 @@ def _get_sinc_resample_kernel( ...@@ -1352,15 +1360,20 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right... # they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for # There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work. # future work.
idx = torch.arange(-width, width + orig_freq) idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
for i in range(new_freq): for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width) t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
t *= math.pi
# we do not use torch.hann_window here as we need to evaluate the window # we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid. # at specific positions, not over a regular grid.
window = torch.cos(t / lowpass_filter_width / 2)**2 if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
elif resampling_method == "kaiser_window":
beta = torch.tensor(beta, dtype=float)
window = torch.i0(beta * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta)
t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window) kernel.mul_(window)
kernels.append(kernel) kernels.append(kernel)
...@@ -1403,6 +1416,8 @@ def resample( ...@@ -1403,6 +1416,8 @@ def resample(
new_freq: float, new_freq: float,
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None,
) -> Tensor: ) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
...@@ -1421,6 +1436,9 @@ def resample( ...@@ -1421,6 +1436,9 @@ def resample(
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
beta (float, optional): The shape parameter used for kaiser window.
Returns: Returns:
Tensor: The waveform at the new frequency of dimension (..., time). Tensor: The waveform at the new frequency of dimension (..., time).
...@@ -1433,6 +1451,7 @@ def resample( ...@@ -1433,6 +1451,7 @@ def resample(
gcd = math.gcd(int(orig_freq), int(new_freq)) gcd = math.gcd(int(orig_freq), int(new_freq))
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff) kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
resampling_method, beta)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width) resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled return resampled
...@@ -657,11 +657,13 @@ class Resample(torch.nn.Module): ...@@ -657,11 +657,13 @@ class Resample(torch.nn.Module):
Args: Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``) orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``) new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``) resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float, optional): The shape parameter used for kaiser window.
""" """
def __init__(self, def __init__(self,
...@@ -669,7 +671,8 @@ class Resample(torch.nn.Module): ...@@ -669,7 +671,8 @@ class Resample(torch.nn.Module):
new_freq: float = 16000, new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation', resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None: rolloff: float = 0.99,
beta: Optional[float] = None) -> None:
super(Resample, self).__init__() super(Resample, self).__init__()
self.orig_freq = orig_freq self.orig_freq = orig_freq
...@@ -680,7 +683,8 @@ class Resample(torch.nn.Module): ...@@ -680,7 +683,8 @@ class Resample(torch.nn.Module):
self.rolloff = rolloff self.rolloff = rolloff
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd, self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff) self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
...@@ -690,12 +694,9 @@ class Resample(torch.nn.Module): ...@@ -690,12 +694,9 @@ class Resample(torch.nn.Module):
Returns: Returns:
Tensor: Output signal of dimension (..., time). Tensor: Output signal of dimension (..., time).
""" """
if self.resampling_method == 'sinc_interpolation':
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd, return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width) self.kernel, self.width)
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
class ComplexNorm(torch.nn.Module): class ComplexNorm(torch.nn.Module):
r"""Compute the norm of complex tensor input. r"""Compute the norm of complex tensor input.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment