Unverified Commit aec0e8c9 authored by moto, committed by GitHub

Add dtype argument for kernel caching precision (#1556)

Since 0.9.0-RC1, `T.Resample` precomputes and caches the resampling kernel, yielding roughly a 10x performance improvement.

The 0.8.0 implementation computed the kernel on the fly, on the same `device`/`dtype` as the input Tensor,
but in the newer version the kernel is precomputed at construction time and cached as `float32`.
This degrades results when resampling in `float64`, because `sinc` values computed in `float32` are not accurate enough for `float64` resampling.

We chose `float32` for the initial cache to keep UX disruption to a minimum, but that left no way to make resampling work well in `float64`. This PR adds a `dtype` argument that can be used to override the cache precision.
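For illustration, a minimal sketch of the intended usage (the `dtype` keyword is what this PR adds; the rest is the existing `T.Resample` API):

```python
import torch
import torchaudio.transforms as T

# float64 input to be resampled from 16 kHz to 44.1 kHz
waveform = torch.rand(1, 16000, dtype=torch.float64)

# Default behavior: the kernel is computed in float64 but cached as float32,
# which limits accuracy when the input is float64.
resampler = T.Resample(16000, 44100)

# New in this PR: override the cache precision so the kernel is
# computed and cached in float64.
resampler_f64 = T.Resample(16000, 44100, dtype=torch.float64)
resampled = resampler_f64(waveform)
```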
parent e9415df4
-import itertools
 import warnings
 import torch
@@ -8,8 +7,8 @@ from torchaudio_unittest.common_utils import (
     TestBaseMixin,
     get_whitenoise,
     get_spectrogram,
+    nested_params,
 )
-from parameterized import parameterized


 def _get_ratio(mat):
@@ -80,13 +79,24 @@ class TransformsTestBase(TestBaseMixin):
             T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
         assert len(caught_warnings) == 0

-    @parameterized.expand(list(itertools.product(
+    @nested_params(
         ["sinc_interpolation", "kaiser_window"],
         [16000, 44100],
-    )))
+    )
     def test_resample_identity(self, resampling_method, sample_rate):
+        """When sampling rate is not changed, the transform returns an identical Tensor"""
         waveform = get_whitenoise(sample_rate=sample_rate, duration=1)
-        resampler = T.Resample(sample_rate, sample_rate)
+        resampler = T.Resample(sample_rate, sample_rate, resampling_method)
         resampled = resampler(waveform)
         self.assertEqual(waveform, resampled)
+
+    @nested_params(
+        ["sinc_interpolation", "kaiser_window"],
+        [None, torch.float64],
+    )
+    def test_resample_cache_dtype(self, resampling_method, dtype):
+        """Providing dtype changes the kernel cache dtype"""
+        transform = T.Resample(16000, 44100, resampling_method, dtype=dtype)
+        assert transform.kernel.dtype == (dtype if dtype is not None else torch.float32)
@@ -664,16 +664,27 @@ class Resample(torch.nn.Module):
         rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
             Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
         beta (float or None): The shape parameter used for kaiser window.
+        dtype (torch.dtype, optional):
+            Determines the precision with which the resampling kernel is pre-computed and cached. If not provided,
+            the kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
+            If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is computed and
+            cached as ``torch.float64``. If you use resampling with lower precision, then instead of providing this
+            argument, please use ``Resample.to(dtype)``, so that the kernel generation is still
+            carried out in ``torch.float64``.
     """

-    def __init__(self,
-                 orig_freq: float = 16000,
-                 new_freq: float = 16000,
-                 resampling_method: str = 'sinc_interpolation',
-                 lowpass_filter_width: int = 6,
-                 rolloff: float = 0.99,
-                 beta: Optional[float] = None) -> None:
-        super(Resample, self).__init__()
+    def __init__(
+            self,
+            orig_freq: float = 16000,
+            new_freq: float = 16000,
+            resampling_method: str = 'sinc_interpolation',
+            lowpass_filter_width: int = 6,
+            rolloff: float = 0.99,
+            beta: Optional[float] = None,
+            *,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__()
         self.orig_freq = orig_freq
         self.new_freq = new_freq
@@ -681,11 +692,13 @@ class Resample(torch.nn.Module):
         self.resampling_method = resampling_method
         self.lowpass_filter_width = lowpass_filter_width
         self.rolloff = rolloff
+        self.beta = beta

         if self.orig_freq != self.new_freq:
-            kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
-                                                           self.lowpass_filter_width, self.rolloff,
-                                                           self.resampling_method, beta)
+            kernel, self.width = _get_sinc_resample_kernel(
+                self.orig_freq, self.new_freq, self.gcd,
+                self.lowpass_filter_width, self.rolloff,
+                self.resampling_method, beta, dtype=dtype)
             self.register_buffer('kernel', kernel)

     def forward(self, waveform: Tensor) -> Tensor:
@@ -698,8 +711,9 @@ class Resample(torch.nn.Module):
         """
         if self.orig_freq == self.new_freq:
             return waveform
-        return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
-                                           self.kernel, self.width)
+        return _apply_sinc_resample_kernel(
+            waveform, self.orig_freq, self.new_freq, self.gcd,
+            self.kernel, self.width)


 class ComplexNorm(torch.nn.Module):
...
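As a usage note (a hedged sketch, not part of the diff): the docstring above recommends `Resample.to(dtype)` for lower-precision use, so that the kernel is still generated in `torch.float64` and only downcast afterwards. Assuming a CUDA device is available (half-precision convolution is generally a GPU path), this might look like:

```python
import torch
import torchaudio.transforms as T

# Construct with the default behavior (kernel computed in float64,
# cached as float32), then cast the whole module down; the kernel
# was still generated at full precision before the downcast.
resampler_f16 = T.Resample(16000, 8000).to(device="cuda", dtype=torch.float16)

waveform = torch.rand(1, 16000, device="cuda", dtype=torch.float16)
resampled = resampler_f16(waveform)
```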