Commit e6bebe6a authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Rename resampling_method options (#2922)

Summary:
resolves https://github.com/pytorch/audio/issues/2891

Rename `resampling_method` options to more accurately describe what is happening. Previously the methods were set to `sinc_interpolation` and `kaiser_window`, which can be confusing as both options actually use sinc interpolation methodology, but differ in the window function used. As a result, rename `sinc_interpolation` to `sinc_interp_hann` and `kaiser_window` to `sinc_interp_kaiser`. Using an old option will throw a warning, and those options will be deprecated in 2 released. The numerical behavior is unchanged.

Pull Request resolved: https://github.com/pytorch/audio/pull/2922

Reviewed By: mthrok

Differential Revision: D42083619

Pulled By: carolineechen

fbshipit-source-id: 9a9a7ea2d2daeadc02d53dddfd26afe249459e70
parent 3cf266a3
......@@ -240,13 +240,13 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
######################################################################
#
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
......@@ -271,7 +271,7 @@ resampled_waveform = F.resample(
resample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
resampling_method="sinc_interp_kaiser",
beta=14.769656459379492,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
......@@ -300,7 +300,7 @@ resampled_waveform = F.resample(
resample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
resampling_method="sinc_interp_kaiser",
beta=8.555504641634386,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
......@@ -344,7 +344,7 @@ def benchmark_resample_functional(
resample_rate,
lowpass_filter_width=6,
rolloff=0.99,
resampling_method="sinc_interpolation",
resampling_method="sinc_interp_hann",
beta=None,
iters=5,
):
......@@ -375,7 +375,7 @@ def benchmark_resample_transforms(
resample_rate,
lowpass_filter_width=6,
rolloff=0.99,
resampling_method="sinc_interpolation",
resampling_method="sinc_interp_hann",
beta=None,
iters=5,
):
......@@ -451,7 +451,7 @@ def benchmark(sample_rate, resample_rate):
kwargs = {
"lowpass_filter_width": 64,
"rolloff": 0.9475937167399596,
"resampling_method": "kaiser_window",
"resampling_method": "sinc_interp_kaiser",
"beta": 14.769656459379492,
}
lib_time = benchmark_resample_librosa(*args, res_type="kaiser_best")
......@@ -464,7 +464,7 @@ def benchmark(sample_rate, resample_rate):
kwargs = {
"lowpass_filter_width": 16,
"rolloff": 0.85,
"resampling_method": "kaiser_window",
"resampling_method": "sinc_interp_kaiser",
"beta": 8.555504641634386,
}
lib_time = benchmark_resample_librosa(*args, res_type="kaiser_fast")
......@@ -531,8 +531,8 @@ plot(df)
# - a larger ``lowpass_filter_width`` results in a larger resampling kernel,
# and therefore increases computation time for both the kernel computation
# and convolution
# - using ``kaiser_window`` results in longer computation times than the default
# ``sinc_interpolation`` because it is more complex to compute the intermediate
# - using ``sinc_interp_kaiser`` results in longer computation times than the default
# ``sinc_interp_hann`` because it is more complex to compute the intermediate
# window values
# - a large GCD between the sample and resample rate will result
# in a simplification that allows for a smaller kernel and faster kernel computation.
......
......@@ -233,7 +233,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
func = partial(F.sliding_window_cmn, **kwargs)
self.assert_batch_consistency(func, inputs=(spectrogram,))
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
@parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform(self, resampling_method):
num_channels = 3
sr = 16000
......
......@@ -20,7 +20,7 @@ from torchaudio_unittest.common_utils import (
class Functional(TestBaseMixin):
def _test_resample_waveform_accuracy(
self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4
self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interp_hann", atol=1e-1, rtol=1e-4
):
# resample the signal and compare it to the ground truth
n_to_trim = 20
......@@ -471,7 +471,7 @@ class Functional(TestBaseMixin):
@parameterized.expand(
list(
itertools.product(
["sinc_interpolation", "kaiser_window"],
["sinc_interp_hann", "sinc_interp_kaiser"],
[16000, 44100],
)
)
......@@ -482,7 +482,7 @@ class Functional(TestBaseMixin):
resampled = F.resample(waveform, sample_rate, sample_rate)
self.assertEqual(waveform, resampled)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
@parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_upsample_size(self, resampling_method):
sr = 16000
waveform = get_whitenoise(
......@@ -492,7 +492,7 @@ class Functional(TestBaseMixin):
upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method)
assert upsampled.size(-1) == waveform.size(-1) * 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
@parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_downsample_size(self, resampling_method):
sr = 16000
waveform = get_whitenoise(
......@@ -502,7 +502,7 @@ class Functional(TestBaseMixin):
downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
assert downsampled.size(-1) == waveform.size(-1) // 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
@parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_identity_size(self, resampling_method):
sr = 16000
waveform = get_whitenoise(
......@@ -515,7 +515,7 @@ class Functional(TestBaseMixin):
@parameterized.expand(
list(
itertools.product(
["sinc_interpolation", "kaiser_window"],
["sinc_interp_hann", "sinc_interp_kaiser"],
list(range(1, 20)),
)
)
......@@ -526,7 +526,7 @@ class Functional(TestBaseMixin):
@parameterized.expand(
list(
itertools.product(
["sinc_interpolation", "kaiser_window"],
["sinc_interp_hann", "sinc_interp_kaiser"],
list(range(1, 20)),
)
)
......
......@@ -600,7 +600,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_resample_sinc(self):
def func(tensor):
sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interp_hann")
tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, (tensor,))
......@@ -616,7 +616,9 @@ class Functional(TempDirMixin, TestBaseMixin):
sr1, sr2 = 16000, 8000
lowpass_filter_width = 6
rolloff = 0.99
self._assert_consistency(F.resample, (tensor, sr1, sr2, lowpass_filter_width, rolloff, "kaiser_window", beta))
self._assert_consistency(
F.resample, (tensor, sr1, sr2, lowpass_filter_width, rolloff, "sinc_interp_kaiser", beta)
)
def test_phase_vocoder(self):
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
......
......@@ -237,7 +237,7 @@ class Tester(common_utils.TorchaudioTestCase):
torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method=invalid_resampling_method)
upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method="sinc_interpolation"
sample_rate, upsample_rate, resampling_method="sinc_interp_hann"
)
up_sampled = upsample_resample(waveform)
......@@ -245,7 +245,7 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
downsample_resample = torchaudio.transforms.Resample(
sample_rate, downsample_rate, resampling_method="sinc_interpolation"
sample_rate, downsample_rate, resampling_method="sinc_interp_hann"
)
down_sampled = downsample_resample(waveform)
......
......@@ -53,7 +53,7 @@ class TransformsTestBase(TestBaseMixin):
assert _get_ratio(relative_diff < 1e-5) > 1e-5
@nested_params(
["sinc_interpolation", "kaiser_window"],
["sinc_interp_hann", "sinc_interp_kaiser"],
[16000, 44100],
)
def test_resample_identity(self, resampling_method, sample_rate):
......@@ -65,7 +65,7 @@ class TransformsTestBase(TestBaseMixin):
self.assertEqual(waveform, resampled)
@nested_params(
["sinc_interpolation", "kaiser_window"],
["sinc_interp_hann", "sinc_interp_kaiser"],
[None, torch.float64],
)
def test_resample_cache_dtype(self, resampling_method, dtype):
......
......@@ -1429,7 +1429,7 @@ def _get_sinc_resample_kernel(
gcd: int,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
resampling_method: str = "sinc_interp_hann",
beta: Optional[float] = None,
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None,
......@@ -1445,7 +1445,17 @@ def _get_sinc_resample_kernel(
"For more information, please refer to https://github.com/pytorch/audio/issues/1487."
)
if resampling_method not in ["sinc_interpolation", "kaiser_window"]:
if resampling_method in ["sinc_interpolation", "kaiser_window"]:
method_map = {
"sinc_interpolation": "sinc_interp_hann",
"kaiser_window": "sinc_interp_kaiser",
}
warnings.warn(
f'"{resampling_method}" resampling method name is being deprecated and replaced by '
f'"{method_map[resampling_method]}" in the next release. '
"The default behavior remains unchanged."
)
elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]:
raise ValueError("Invalid resampling method: {}".format(resampling_method))
orig_freq = int(orig_freq) // gcd
......@@ -1492,10 +1502,10 @@ def _get_sinc_resample_kernel(
# we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation":
if resampling_method == "sinc_interp_hann":
window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
else:
# kaiser_window
# sinc_interp_kaiser
if beta is None:
beta = 14.769656459379492
beta_tensor = torch.tensor(float(beta))
......@@ -1549,7 +1559,7 @@ def resample(
new_freq: int,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
resampling_method: str = "sinc_interp_hann",
beta: Optional[float] = None,
) -> Tensor:
r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`.
......@@ -1571,7 +1581,7 @@ def resample(
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
resampling_method (str, optional): The resampling method to use.
Options: [``"sinc_interpolation"``, ``"kaiser_window"``] (Default: ``"sinc_interpolation"``)
Options: [``"sinc_interp_hann"``, ``"sinc_interp_kaiser"``] (Default: ``"sinc_interp_hann"``)
beta (float or None, optional): The shape parameter used for kaiser window.
Returns:
......
......@@ -942,7 +942,7 @@ class Resample(torch.nn.Module):
orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (int, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method to use.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``"sinc_interpolation"``)
Options: [``sinc_interp_hann``, ``sinc_interp_kaiser``] (Default: ``"sinc_interp_hann"``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
......@@ -966,7 +966,7 @@ class Resample(torch.nn.Module):
self,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = "sinc_interpolation",
resampling_method: str = "sinc_interp_hann",
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment