Unverified Commit c2740644 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Add rolloff param to resample (#1488)

parent 32f661f0
...@@ -755,7 +755,8 @@ def mfcc( ...@@ -755,7 +755,8 @@ def mfcc(
def resample_waveform(waveform: Tensor, def resample_waveform(waveform: Tensor,
orig_freq: float, orig_freq: float,
new_freq: float, new_freq: float,
lowpass_filter_width: int = 6) -> Tensor: lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> Tensor:
r"""Resamples the waveform at the new frequency. r"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``. This is a wrapper around ``torchaudio.functional.resample``.
...@@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor, ...@@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor,
new_freq (float): The desired frequency new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Returns: Returns:
Tensor: The waveform at the new frequency Tensor: The waveform at the new frequency
""" """
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width) return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff)
...@@ -1298,8 +1298,13 @@ def compute_kaldi_pitch( ...@@ -1298,8 +1298,13 @@ def compute_kaldi_pitch(
return result return result
def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int, def _get_sinc_resample_kernel(
device: torch.device, dtype: torch.dtype): orig_freq: int,
new_freq: int,
lowpass_filter_width: int,
rolloff: float,
device: torch.device,
dtype: torch.dtype):
assert lowpass_filter_width > 0 assert lowpass_filter_width > 0
kernels = [] kernels = []
base_freq = min(orig_freq, new_freq) base_freq = min(orig_freq, new_freq)
...@@ -1307,7 +1312,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt ...@@ -1307,7 +1312,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt
# At first I thought I only needed this when downsampling, but when upsampling # At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding, # you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts. # which will add high freq artifacts.
base_freq *= 0.99 base_freq *= rolloff
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula: # using the sinc interpolation formula:
...@@ -1352,7 +1357,8 @@ def resample( ...@@ -1352,7 +1357,8 @@ def resample(
waveform: Tensor, waveform: Tensor,
orig_freq: float, orig_freq: float,
new_freq: float, new_freq: float,
lowpass_filter_width: int = 6 lowpass_filter_width: int = 6,
rolloff: float = 0.99,
) -> Tensor: ) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
...@@ -1369,6 +1375,8 @@ def resample( ...@@ -1369,6 +1375,8 @@ def resample(
new_freq (float): The desired frequency new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Returns: Returns:
Tensor: The waveform at the new frequency of dimension (..., time). Tensor: The waveform at the new frequency of dimension (..., time).
...@@ -1386,7 +1394,7 @@ def resample( ...@@ -1386,7 +1394,7 @@ def resample(
new_freq = new_freq // gcd new_freq = new_freq // gcd
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
waveform.device, waveform.dtype) rolloff, waveform.device, waveform.dtype)
num_wavs, length = waveform.shape num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
......
...@@ -640,16 +640,24 @@ class Resample(torch.nn.Module): ...@@ -640,16 +640,24 @@ class Resample(torch.nn.Module):
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``) orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``) new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``) resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
""" """
def __init__(self, def __init__(self,
orig_freq: int = 16000, orig_freq: int = 16000,
new_freq: int = 16000, new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation') -> None: resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
super(Resample, self).__init__() super(Resample, self).__init__()
self.orig_freq = orig_freq self.orig_freq = orig_freq
self.new_freq = new_freq self.new_freq = new_freq
self.resampling_method = resampling_method self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
...@@ -660,7 +668,7 @@ class Resample(torch.nn.Module): ...@@ -660,7 +668,7 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time). Tensor: Output signal of dimension (..., time).
""" """
if self.resampling_method == 'sinc_interpolation': if self.resampling_method == 'sinc_interpolation':
return F.resample(waveform, self.orig_freq, self.new_freq) return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff)
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method)) raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment