Unverified Commit 52e7bfd9 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Precompute transforms.Resample kernel (#1499)

parent 8a86c463
...@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch( ...@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
def _get_sinc_resample_kernel( def _get_sinc_resample_kernel(
orig_freq: int, orig_freq: float,
new_freq: int, new_freq: float,
gcd: int,
lowpass_filter_width: int, lowpass_filter_width: int,
rolloff: float, rolloff: float):
device: torch.device,
dtype: torch.dtype): if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
assert lowpass_filter_width > 0 assert lowpass_filter_width > 0
kernels = [] kernels = []
base_freq = min(orig_freq, new_freq) base_freq = min(orig_freq, new_freq)
...@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel( ...@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right... # they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for # There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work. # future work.
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype) idx = torch.arange(-width, width + orig_freq)
for i in range(new_freq): for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq t = (-i / new_freq + idx / orig_freq) * base_freq
...@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel( ...@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: float,
new_freq: float,
gcd: int,
kernel: Tensor,
width: int,
):
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)
num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]
# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
return resampled
def resample( def resample(
waveform: Tensor, waveform: Tensor,
orig_freq: float, orig_freq: float,
...@@ -1380,42 +1424,15 @@ def resample( ...@@ -1380,42 +1424,15 @@ def resample(
Returns: Returns:
Tensor: The waveform at the new frequency of dimension (..., time). Tensor: The waveform at the new frequency of dimension (..., time).
Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
more efficient computation if resampling multiple waveforms with the same resampling parameters.
""" """
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
assert orig_freq > 0.0 and new_freq > 0.0 assert orig_freq > 0.0 and new_freq > 0.0
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): gcd = math.gcd(int(orig_freq), int(new_freq))
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
rolloff, waveform.device, waveform.dtype)
num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]
# unpack batch kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled return resampled
...@@ -8,6 +8,10 @@ import torch ...@@ -8,6 +8,10 @@ import torch
from torch import Tensor from torch import Tensor
from torchaudio import functional as F from torchaudio import functional as F
from .functional.functional import (
_get_sinc_resample_kernel,
_apply_sinc_resample_kernel,
)
__all__ = [ __all__ = [
'Spectrogram', 'Spectrogram',
...@@ -661,18 +665,23 @@ class Resample(torch.nn.Module): ...@@ -661,18 +665,23 @@ class Resample(torch.nn.Module):
""" """
def __init__(self, def __init__(self,
orig_freq: int = 16000, orig_freq: float = 16000,
new_freq: int = 16000, new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation', resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None: rolloff: float = 0.99) -> None:
super(Resample, self).__init__() super(Resample, self).__init__()
self.orig_freq = orig_freq self.orig_freq = orig_freq
self.new_freq = new_freq self.new_freq = new_freq
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
self.resampling_method = resampling_method self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff self.rolloff = rolloff
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff)
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
Args: Args:
...@@ -682,7 +691,8 @@ class Resample(torch.nn.Module): ...@@ -682,7 +691,8 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time). Tensor: Output signal of dimension (..., time).
""" """
if self.resampling_method == 'sinc_interpolation': if self.resampling_method == 'sinc_interpolation':
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff) return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method)) raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment