Commit e6bebe6a authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Rename resampling_method options (#2922)

Summary:
resolves https://github.com/pytorch/audio/issues/2891

Rename `resampling_method` options to more accurately describe what is happening. Previously the methods were set to `sinc_interpolation` and `kaiser_window`, which can be confusing as both options actually use sinc interpolation methodology, but differ in the window function used. As a result, rename `sinc_interpolation` to `sinc_interp_hann` and `kaiser_window` to `sinc_interp_kaiser`. Using an old option will throw a warning, and those options will be deprecated in 2 released. The numerical behavior is unchanged.

Pull Request resolved: https://github.com/pytorch/audio/pull/2922

Reviewed By: mthrok

Differential Revision: D42083619

Pulled By: carolineechen

fbshipit-source-id: 9a9a7ea2d2daeadc02d53dddfd26afe249459e70
parent 3cf266a3
...@@ -240,13 +240,13 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8") ...@@ -240,13 +240,13 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000 resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation") resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default") plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
###################################################################### ######################################################################
# #
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window") resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
...@@ -271,7 +271,7 @@ resampled_waveform = F.resample( ...@@ -271,7 +271,7 @@ resampled_waveform = F.resample(
resample_rate, resample_rate,
lowpass_filter_width=64, lowpass_filter_width=64,
rolloff=0.9475937167399596, rolloff=0.9475937167399596,
resampling_method="kaiser_window", resampling_method="sinc_interp_kaiser",
beta=14.769656459379492, beta=14.769656459379492,
) )
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
...@@ -300,7 +300,7 @@ resampled_waveform = F.resample( ...@@ -300,7 +300,7 @@ resampled_waveform = F.resample(
resample_rate, resample_rate,
lowpass_filter_width=16, lowpass_filter_width=16,
rolloff=0.85, rolloff=0.85,
resampling_method="kaiser_window", resampling_method="sinc_interp_kaiser",
beta=8.555504641634386, beta=8.555504641634386,
) )
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
...@@ -344,7 +344,7 @@ def benchmark_resample_functional( ...@@ -344,7 +344,7 @@ def benchmark_resample_functional(
resample_rate, resample_rate,
lowpass_filter_width=6, lowpass_filter_width=6,
rolloff=0.99, rolloff=0.99,
resampling_method="sinc_interpolation", resampling_method="sinc_interp_hann",
beta=None, beta=None,
iters=5, iters=5,
): ):
...@@ -375,7 +375,7 @@ def benchmark_resample_transforms( ...@@ -375,7 +375,7 @@ def benchmark_resample_transforms(
resample_rate, resample_rate,
lowpass_filter_width=6, lowpass_filter_width=6,
rolloff=0.99, rolloff=0.99,
resampling_method="sinc_interpolation", resampling_method="sinc_interp_hann",
beta=None, beta=None,
iters=5, iters=5,
): ):
...@@ -451,7 +451,7 @@ def benchmark(sample_rate, resample_rate): ...@@ -451,7 +451,7 @@ def benchmark(sample_rate, resample_rate):
kwargs = { kwargs = {
"lowpass_filter_width": 64, "lowpass_filter_width": 64,
"rolloff": 0.9475937167399596, "rolloff": 0.9475937167399596,
"resampling_method": "kaiser_window", "resampling_method": "sinc_interp_kaiser",
"beta": 14.769656459379492, "beta": 14.769656459379492,
} }
lib_time = benchmark_resample_librosa(*args, res_type="kaiser_best") lib_time = benchmark_resample_librosa(*args, res_type="kaiser_best")
...@@ -464,7 +464,7 @@ def benchmark(sample_rate, resample_rate): ...@@ -464,7 +464,7 @@ def benchmark(sample_rate, resample_rate):
kwargs = { kwargs = {
"lowpass_filter_width": 16, "lowpass_filter_width": 16,
"rolloff": 0.85, "rolloff": 0.85,
"resampling_method": "kaiser_window", "resampling_method": "sinc_interp_kaiser",
"beta": 8.555504641634386, "beta": 8.555504641634386,
} }
lib_time = benchmark_resample_librosa(*args, res_type="kaiser_fast") lib_time = benchmark_resample_librosa(*args, res_type="kaiser_fast")
...@@ -531,8 +531,8 @@ plot(df) ...@@ -531,8 +531,8 @@ plot(df)
# - a larger ``lowpass_filter_width`` results in a larger resampling kernel, # - a larger ``lowpass_filter_width`` results in a larger resampling kernel,
# and therefore increases computation time for both the kernel computation # and therefore increases computation time for both the kernel computation
# and convolution # and convolution
# - using ``kaiser_window`` results in longer computation times than the default # - using ``sinc_interp_kaiser`` results in longer computation times than the default
# ``sinc_interpolation`` because it is more complex to compute the intermediate # ``sinc_interp_hann`` because it is more complex to compute the intermediate
# window values # window values
# - a large GCD between the sample and resample rate will result # - a large GCD between the sample and resample rate will result
# in a simplification that allows for a smaller kernel and faster kernel computation. # in a simplification that allows for a smaller kernel and faster kernel computation.
......
...@@ -233,7 +233,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -233,7 +233,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
func = partial(F.sliding_window_cmn, **kwargs) func = partial(F.sliding_window_cmn, **kwargs)
self.assert_batch_consistency(func, inputs=(spectrogram,)) self.assert_batch_consistency(func, inputs=(spectrogram,))
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform(self, resampling_method): def test_resample_waveform(self, resampling_method):
num_channels = 3 num_channels = 3
sr = 16000 sr = 16000
......
...@@ -20,7 +20,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -20,7 +20,7 @@ from torchaudio_unittest.common_utils import (
class Functional(TestBaseMixin): class Functional(TestBaseMixin):
def _test_resample_waveform_accuracy( def _test_resample_waveform_accuracy(
self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4 self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interp_hann", atol=1e-1, rtol=1e-4
): ):
# resample the signal and compare it to the ground truth # resample the signal and compare it to the ground truth
n_to_trim = 20 n_to_trim = 20
...@@ -471,7 +471,7 @@ class Functional(TestBaseMixin): ...@@ -471,7 +471,7 @@ class Functional(TestBaseMixin):
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
[16000, 44100], [16000, 44100],
) )
) )
...@@ -482,7 +482,7 @@ class Functional(TestBaseMixin): ...@@ -482,7 +482,7 @@ class Functional(TestBaseMixin):
resampled = F.resample(waveform, sample_rate, sample_rate) resampled = F.resample(waveform, sample_rate, sample_rate)
self.assertEqual(waveform, resampled) self.assertEqual(waveform, resampled)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_upsample_size(self, resampling_method): def test_resample_waveform_upsample_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
...@@ -492,7 +492,7 @@ class Functional(TestBaseMixin): ...@@ -492,7 +492,7 @@ class Functional(TestBaseMixin):
upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method) upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method)
assert upsampled.size(-1) == waveform.size(-1) * 2 assert upsampled.size(-1) == waveform.size(-1) * 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_downsample_size(self, resampling_method): def test_resample_waveform_downsample_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
...@@ -502,7 +502,7 @@ class Functional(TestBaseMixin): ...@@ -502,7 +502,7 @@ class Functional(TestBaseMixin):
downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method) downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
assert downsampled.size(-1) == waveform.size(-1) // 2 assert downsampled.size(-1) == waveform.size(-1) // 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_identity_size(self, resampling_method): def test_resample_waveform_identity_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
...@@ -515,7 +515,7 @@ class Functional(TestBaseMixin): ...@@ -515,7 +515,7 @@ class Functional(TestBaseMixin):
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
list(range(1, 20)), list(range(1, 20)),
) )
) )
...@@ -526,7 +526,7 @@ class Functional(TestBaseMixin): ...@@ -526,7 +526,7 @@ class Functional(TestBaseMixin):
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
list(range(1, 20)), list(range(1, 20)),
) )
) )
......
...@@ -600,7 +600,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -600,7 +600,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_resample_sinc(self): def test_resample_sinc(self):
def func(tensor): def func(tensor):
sr1, sr2 = 16000, 8000 sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation") return F.resample(tensor, sr1, sr2, resampling_method="sinc_interp_hann")
tensor = common_utils.get_whitenoise(sample_rate=16000) tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, (tensor,)) self._assert_consistency(func, (tensor,))
...@@ -616,7 +616,9 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -616,7 +616,9 @@ class Functional(TempDirMixin, TestBaseMixin):
sr1, sr2 = 16000, 8000 sr1, sr2 = 16000, 8000
lowpass_filter_width = 6 lowpass_filter_width = 6
rolloff = 0.99 rolloff = 0.99
self._assert_consistency(F.resample, (tensor, sr1, sr2, lowpass_filter_width, rolloff, "kaiser_window", beta)) self._assert_consistency(
F.resample, (tensor, sr1, sr2, lowpass_filter_width, rolloff, "sinc_interp_kaiser", beta)
)
def test_phase_vocoder(self): def test_phase_vocoder(self):
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2)) tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
......
...@@ -237,7 +237,7 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -237,7 +237,7 @@ class Tester(common_utils.TorchaudioTestCase):
torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method=invalid_resampling_method) torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method=invalid_resampling_method)
upsample_resample = torchaudio.transforms.Resample( upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method="sinc_interpolation" sample_rate, upsample_rate, resampling_method="sinc_interp_hann"
) )
up_sampled = upsample_resample(waveform) up_sampled = upsample_resample(waveform)
...@@ -245,7 +245,7 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -245,7 +245,7 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2) self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
downsample_resample = torchaudio.transforms.Resample( downsample_resample = torchaudio.transforms.Resample(
sample_rate, downsample_rate, resampling_method="sinc_interpolation" sample_rate, downsample_rate, resampling_method="sinc_interp_hann"
) )
down_sampled = downsample_resample(waveform) down_sampled = downsample_resample(waveform)
......
...@@ -53,7 +53,7 @@ class TransformsTestBase(TestBaseMixin): ...@@ -53,7 +53,7 @@ class TransformsTestBase(TestBaseMixin):
assert _get_ratio(relative_diff < 1e-5) > 1e-5 assert _get_ratio(relative_diff < 1e-5) > 1e-5
@nested_params( @nested_params(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
[16000, 44100], [16000, 44100],
) )
def test_resample_identity(self, resampling_method, sample_rate): def test_resample_identity(self, resampling_method, sample_rate):
...@@ -65,7 +65,7 @@ class TransformsTestBase(TestBaseMixin): ...@@ -65,7 +65,7 @@ class TransformsTestBase(TestBaseMixin):
self.assertEqual(waveform, resampled) self.assertEqual(waveform, resampled)
@nested_params( @nested_params(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
[None, torch.float64], [None, torch.float64],
) )
def test_resample_cache_dtype(self, resampling_method, dtype): def test_resample_cache_dtype(self, resampling_method, dtype):
......
...@@ -1429,7 +1429,7 @@ def _get_sinc_resample_kernel( ...@@ -1429,7 +1429,7 @@ def _get_sinc_resample_kernel(
gcd: int, gcd: int,
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation", resampling_method: str = "sinc_interp_hann",
beta: Optional[float] = None, beta: Optional[float] = None,
device: torch.device = torch.device("cpu"), device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
...@@ -1445,7 +1445,17 @@ def _get_sinc_resample_kernel( ...@@ -1445,7 +1445,17 @@ def _get_sinc_resample_kernel(
"For more information, please refer to https://github.com/pytorch/audio/issues/1487." "For more information, please refer to https://github.com/pytorch/audio/issues/1487."
) )
if resampling_method not in ["sinc_interpolation", "kaiser_window"]: if resampling_method in ["sinc_interpolation", "kaiser_window"]:
method_map = {
"sinc_interpolation": "sinc_interp_hann",
"kaiser_window": "sinc_interp_kaiser",
}
warnings.warn(
f'"{resampling_method}" resampling method name is being deprecated and replaced by '
f'"{method_map[resampling_method]}" in the next release. '
"The default behavior remains unchanged."
)
elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]:
raise ValueError("Invalid resampling method: {}".format(resampling_method)) raise ValueError("Invalid resampling method: {}".format(resampling_method))
orig_freq = int(orig_freq) // gcd orig_freq = int(orig_freq) // gcd
...@@ -1492,10 +1502,10 @@ def _get_sinc_resample_kernel( ...@@ -1492,10 +1502,10 @@ def _get_sinc_resample_kernel(
# we do not use built in torch windows here as we need to evaluate the window # we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid. # at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation": if resampling_method == "sinc_interp_hann":
window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2 window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
else: else:
# kaiser_window # sinc_interp_kaiser
if beta is None: if beta is None:
beta = 14.769656459379492 beta = 14.769656459379492
beta_tensor = torch.tensor(float(beta)) beta_tensor = torch.tensor(float(beta))
...@@ -1549,7 +1559,7 @@ def resample( ...@@ -1549,7 +1559,7 @@ def resample(
new_freq: int, new_freq: int,
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation", resampling_method: str = "sinc_interp_hann",
beta: Optional[float] = None, beta: Optional[float] = None,
) -> Tensor: ) -> Tensor:
r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`. r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`.
...@@ -1571,7 +1581,7 @@ def resample( ...@@ -1571,7 +1581,7 @@ def resample(
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
resampling_method (str, optional): The resampling method to use. resampling_method (str, optional): The resampling method to use.
Options: [``"sinc_interpolation"``, ``"kaiser_window"``] (Default: ``"sinc_interpolation"``) Options: [``"sinc_interp_hann"``, ``"sinc_interp_kaiser"``] (Default: ``"sinc_interp_hann"``)
beta (float or None, optional): The shape parameter used for kaiser window. beta (float or None, optional): The shape parameter used for kaiser window.
Returns: Returns:
......
...@@ -942,7 +942,7 @@ class Resample(torch.nn.Module): ...@@ -942,7 +942,7 @@ class Resample(torch.nn.Module):
orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``) orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (int, optional): The desired frequency. (Default: ``16000``) new_freq (int, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method to use. resampling_method (str, optional): The resampling method to use.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``"sinc_interpolation"``) Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. (Default: ``6``) but less efficient. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
...@@ -966,7 +966,7 @@ class Resample(torch.nn.Module): ...@@ -966,7 +966,7 @@ class Resample(torch.nn.Module):
self, self,
orig_freq: int = 16000, orig_freq: int = 16000,
new_freq: int = 16000, new_freq: int = 16000,
resampling_method: str = "sinc_interpolation", resampling_method: str = "sinc_interp_hann",
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
beta: Optional[float] = None, beta: Optional[float] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment