Unverified commit 079b3f5d authored by Caroline Chen, committed by GitHub
Browse files

[BC-Breaking] Avoid resampling kernel device and dtype moves (#1514)

Precomputing and caching the resampling kernel in transforms provides speed improvements for resample, but no longer handles the automatic device and dtype recognition and construction based on input waveform. This is BC-breaking if users do not manually move the transforms object to the correct device and dtype, in which case calls to resample will fail if the input waveform is on gpu, or not of float32 dtype. Precomputing the kernel additionally results in very minor precision differences from previous implementation.
parent 264ab15a
......@@ -1305,7 +1305,9 @@ def _get_sinc_resample_kernel(
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float]):
beta: Optional[float],
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
......@@ -1360,7 +1362,8 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
idx_dtype = dtype if dtype is not None else torch.float64
idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
......@@ -1379,7 +1382,10 @@ def _get_sinc_resample_kernel(
kernels.append(kernel)
scale = base_freq / orig_freq
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
if dtype is None:
kernels = kernels.to(dtype=torch.float32)
return kernels, width
def _apply_sinc_resample_kernel(
......@@ -1396,7 +1402,6 @@ def _apply_sinc_resample_kernel(
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)
num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
......@@ -1452,6 +1457,6 @@ def resample(
gcd = math.gcd(int(orig_freq), int(new_freq))
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
resampling_method, beta)
resampling_method, beta, waveform.device, waveform.dtype)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
......@@ -673,6 +673,11 @@ class Resample(torch.nn.Module):
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float, optional): The shape parameter used for kaiser window.
Note: If resampling on waveforms of higher precision than float32, there may be a small loss of precision
because the kernel is cached once as float32. If high precision resampling is important for your application,
the functional form will retain higher precision, but run slower because it does not cache the kernel.
Alternatively, you could rewrite a transform that caches a higher precision kernel.
"""
def __init__(self,
......@@ -691,9 +696,10 @@ class Resample(torch.nn.Module):
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
self.register_buffer('kernel', kernel)
def forward(self, waveform: Tensor) -> Tensor:
r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment