Unverified Commit 52e7bfd9 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Precompute transforms.Resample kernel (#1499)

parent 8a86c463
......@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
orig_freq: float,
new_freq: float,
gcd: int,
lowpass_filter_width: int,
rolloff: float,
device: torch.device,
dtype: torch.dtype):
rolloff: float):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
......@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
idx = torch.arange(-width, width + orig_freq)
for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
......@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: float,
new_freq: float,
gcd: int,
kernel: Tensor,
width: int,
):
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)
num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]
# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
return resampled
def resample(
waveform: Tensor,
orig_freq: float,
......@@ -1380,42 +1424,15 @@ def resample(
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
more efficient computation if resampling multiple waveforms with the same resampling parameters.
"""
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
assert orig_freq > 0.0 and new_freq > 0.0
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
rolloff, waveform.device, waveform.dtype)
num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]
gcd = math.gcd(int(orig_freq), int(new_freq))
# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
......@@ -8,6 +8,10 @@ import torch
from torch import Tensor
from torchaudio import functional as F
from .functional.functional import (
_get_sinc_resample_kernel,
_apply_sinc_resample_kernel,
)
__all__ = [
'Spectrogram',
......@@ -661,18 +665,23 @@ class Resample(torch.nn.Module):
"""
def __init__(self,
orig_freq: int = 16000,
new_freq: int = 16000,
orig_freq: float = 16000,
new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
super(Resample, self).__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff)
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
......@@ -682,7 +691,8 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff)
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment