Unverified Commit c2740644 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Add rolloff param to resample (#1488)

parent 32f661f0
......@@ -755,7 +755,8 @@ def mfcc(
def resample_waveform(waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6) -> Tensor:
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> Tensor:
r"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``.
......@@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor,
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter. Larger values give a sharper
filter but are less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist frequency.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
Returns:
Tensor: The waveform at the new frequency
"""
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff)
......@@ -1298,8 +1298,13 @@ def compute_kaldi_pitch(
return result
def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
device: torch.device, dtype: torch.dtype):
def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
lowpass_filter_width: int,
rolloff: float,
device: torch.device,
dtype: torch.dtype):
assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
......@@ -1307,7 +1312,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
base_freq *= 0.99
base_freq *= rolloff
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
......@@ -1352,7 +1357,8 @@ def resample(
waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
......@@ -1369,6 +1375,8 @@ def resample(
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter. Larger values give a sharper
filter but are less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist frequency.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
......@@ -1386,7 +1394,7 @@ def resample(
new_freq = new_freq // gcd
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
waveform.device, waveform.dtype)
rolloff, waveform.device, waveform.dtype)
num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
......
......@@ -640,16 +640,24 @@ class Resample(torch.nn.Module):
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter. Larger values give a sharper
filter but are less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist frequency.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
"""
def __init__(self,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation') -> None:
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
super(Resample, self).__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff
def forward(self, waveform: Tensor) -> Tensor:
r"""
......@@ -660,7 +668,7 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
return F.resample(waveform, self.orig_freq, self.new_freq)
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff)
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment