Unverified Commit aec0e8c9 authored by moto's avatar moto Committed by GitHub
Browse files

Add dtype argument for kernel caching precision (#1556)

Since 0.9.0-RC1, `T.Resample` precomputes and caches resampling kernel for performance improvement. (10x improvement).

The implementation from 0.8.0 computed the kernel on-the-fly on the same `device`/`dtype` as the input Tensor, 
but in the newer version, the kernel is precomputed at the construction time and is cached with `float32` first.
This causes degradation if one wants to perform resampling on `float64`, because `sinc` values computed on `float32`s are not good enough for resampling in `float64`.

The reason we decided to use `float32` for initial caching is to keep the UX disruption to a minimum, but there was no way to make it work for `float64`. This PR adds a `dtype` argument that can be used to override the cache precision.
parent e9415df4
import itertools
import warnings
import torch
......@@ -8,8 +7,8 @@ from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
nested_params,
)
from parameterized import parameterized
def _get_ratio(mat):
......@@ -80,13 +79,24 @@ class TransformsTestBase(TestBaseMixin):
T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
assert len(caught_warnings) == 0
@parameterized.expand(list(itertools.product(
@nested_params(
["sinc_interpolation", "kaiser_window"],
[16000, 44100],
)))
)
def test_resample_identity(self, resampling_method, sample_rate):
    """Resampling a waveform to its own sample rate must be a no-op."""
    original = get_whitenoise(sample_rate=sample_rate, duration=1)
    # Same input and output rate: the transform should return the signal unchanged.
    transform = T.Resample(sample_rate, sample_rate, resampling_method)
    output = transform(original)
    self.assertEqual(original, output)
@nested_params(
    ["sinc_interpolation", "kaiser_window"],
    [None, torch.float64],
)
def test_resample_cache_dtype(self, resampling_method, dtype):
    """Providing ``dtype`` changes the dtype of the cached resampling kernel.

    When ``dtype`` is ``None``, the kernel must be cached as the default
    ``torch.float32``; otherwise it must match the requested dtype.
    """
    transform = T.Resample(16000, 44100, resampling_method, dtype=dtype)
    # BUG FIX: the original `assert a == b if cond else default` parsed as
    # `assert (a == b) if cond else default`, so the dtype=None case reduced
    # to `assert torch.float32` (always truthy) and was never checked.
    expected = dtype if dtype is not None else torch.float32
    assert transform.kernel.dtype == expected
......@@ -664,16 +664,27 @@ class Resample(torch.nn.Module):
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float or None): The shape parameter used for kaiser window.
dtype (torch.dtype, optional):
Determines the precision with which the resampling kernel is pre-computed and cached. If not provided,
the kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is computed and
cached as ``torch.float64``. If you use resampling with lower precision, then instead of providing
this argument, please use ``Resample.to(dtype)``, so that the kernel generation is still
carried out on ``torch.float64``.
"""
def __init__(self,
orig_freq: float = 16000,
new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None) -> None:
super(Resample, self).__init__()
def __init__(
self,
orig_freq: float = 16000,
new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None,
*,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
......@@ -681,11 +692,13 @@ class Resample(torch.nn.Module):
self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff
self.beta = beta
if self.orig_freq != self.new_freq:
kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
kernel, self.width = _get_sinc_resample_kernel(
self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta, dtype=dtype)
self.register_buffer('kernel', kernel)
def forward(self, waveform: Tensor) -> Tensor:
    r"""Resample *waveform* from ``self.orig_freq`` to ``self.new_freq``.

    Args:
        waveform (Tensor): Audio tensor of dimension `(..., time)`.

    Returns:
        Tensor: Resampled audio of dimension `(..., time)`.
    """
    # Identical rates: resampling is the identity, and no kernel was cached.
    if self.orig_freq == self.new_freq:
        return waveform
    resampled = _apply_sinc_resample_kernel(
        waveform, self.orig_freq, self.new_freq, self.gcd,
        self.kernel, self.width)
    return resampled
class ComplexNorm(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment