Commit 25a8adf6 (unverified), authored by Caroline Chen, committed by GitHub
Browse files

[BC-Breaking] Ensure integer input frequencies for resample (#1857)

parent 483d8fae
...@@ -27,10 +27,10 @@ class Functional(TestBaseMixin): ...@@ -27,10 +27,10 @@ class Functional(TestBaseMixin):
new_sample_rate = sample_rate new_sample_rate = sample_rate
if up_scale_factor is not None: if up_scale_factor is not None:
new_sample_rate *= up_scale_factor new_sample_rate = int(new_sample_rate * up_scale_factor)
if down_scale_factor is not None: if down_scale_factor is not None:
new_sample_rate //= down_scale_factor new_sample_rate = int(new_sample_rate / down_scale_factor)
duration = 5 # seconds duration = 5 # seconds
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
...@@ -439,25 +439,6 @@ class Functional(TestBaseMixin): ...@@ -439,25 +439,6 @@ class Functional(TestBaseMixin):
def test_resample_waveform_upsample_accuracy(self, resampling_method, i): def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method) self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
def test_resample_no_warning(self):
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.resample(waveform, float(sample_rate), sample_rate / 2.)
assert len(w) == 0
def test_resample_warning(self):
"""resample should throw a warning if an input frequency is not of an integer value"""
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.resample(waveform, sample_rate, 5512.5)
assert len(w) == 1
@nested_params( @nested_params(
[0.5, 1.01, 1.3], [0.5, 1.01, 1.3],
[True, False], [True, False],
......
...@@ -659,7 +659,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -659,7 +659,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_resample_sinc(self): def test_resample_sinc(self):
def func(tensor): def func(tensor):
sr1, sr2 = 16000., 8000. sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation") return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")
tensor = common_utils.get_whitenoise(sample_rate=16000) tensor = common_utils.get_whitenoise(sample_rate=16000)
...@@ -667,11 +667,11 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -667,11 +667,11 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_resample_kaiser(self): def test_resample_kaiser(self):
def func(tensor): def func(tensor):
sr1, sr2 = 16000., 8000. sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window") return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window")
def func_beta(tensor): def func_beta(tensor):
sr1, sr2 = 16000., 8000. sr1, sr2 = 16000, 8000
beta = 6. beta = 6.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta) return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)
......
...@@ -84,7 +84,7 @@ class Transforms(TestBaseMixin): ...@@ -84,7 +84,7 @@ class Transforms(TestBaseMixin):
def test_Resample(self): def test_Resample(self):
sr1, sr2 = 16000, 8000 sr1, sr2 = 16000, 8000
tensor = common_utils.get_whitenoise(sample_rate=sr1) tensor = common_utils.get_whitenoise(sample_rate=sr1)
self._assert_consistency(T.Resample(float(sr1), float(sr2)), tensor) self._assert_consistency(T.Resample(sr1, sr2), tensor)
def test_ComplexNorm(self): def test_ComplexNorm(self):
tensor = torch.rand((1, 2, 201, 2)) tensor = torch.rand((1, 2, 201, 2))
......
...@@ -1471,8 +1471,8 @@ def compute_kaldi_pitch( ...@@ -1471,8 +1471,8 @@ def compute_kaldi_pitch(
def _get_sinc_resample_kernel( def _get_sinc_resample_kernel(
orig_freq: float, orig_freq: int,
new_freq: float, new_freq: int,
gcd: int, gcd: int,
lowpass_filter_width: int, lowpass_filter_width: int,
rolloff: float, rolloff: float,
...@@ -1482,16 +1482,13 @@ def _get_sinc_resample_kernel( ...@@ -1482,16 +1482,13 @@ def _get_sinc_resample_kernel(
dtype: Optional[torch.dtype] = None): dtype: Optional[torch.dtype] = None):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn( raise Exception(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality " "Frequencies must be of integer type to ensure quality resampling computation. "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. " "To work around this, manually convert both frequencies to integer values "
"Using non-integer valued frequencies will throw an error in release 0.10. " "that maintain their resampling rate ratio before passing them into the function. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use " "Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` " "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
"For more information or to leave feedback about this change, please refer to " "For more information, please refer to https://github.com/pytorch/audio/issues/1487."
"https://github.com/pytorch/audio/issues/1487."
) )
if resampling_method not in ['sinc_interpolation', 'kaiser_window']: if resampling_method not in ['sinc_interpolation', 'kaiser_window']:
...@@ -1562,8 +1559,8 @@ def _get_sinc_resample_kernel( ...@@ -1562,8 +1559,8 @@ def _get_sinc_resample_kernel(
def _apply_sinc_resample_kernel( def _apply_sinc_resample_kernel(
waveform: Tensor, waveform: Tensor,
orig_freq: float, orig_freq: int,
new_freq: float, new_freq: int,
gcd: int, gcd: int,
kernel: Tensor, kernel: Tensor,
width: int, width: int,
...@@ -1589,8 +1586,8 @@ def _apply_sinc_resample_kernel( ...@@ -1589,8 +1586,8 @@ def _apply_sinc_resample_kernel(
def resample( def resample(
waveform: Tensor, waveform: Tensor,
orig_freq: float, orig_freq: int,
new_freq: float, new_freq: int,
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation", resampling_method: str = "sinc_interpolation",
...@@ -1606,8 +1603,8 @@ def resample( ...@@ -1606,8 +1603,8 @@ def resample(
Args: Args:
waveform (Tensor): The input signal of dimension `(..., time)` waveform (Tensor): The input signal of dimension `(..., time)`
orig_freq (float): The original frequency of the signal orig_freq (int): The original frequency of the signal
new_freq (float): The desired frequency new_freq (int): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. (Default: ``6``) but less efficient. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
...@@ -1736,7 +1733,7 @@ def pitch_shift( ...@@ -1736,7 +1733,7 @@ def pitch_shift(
win_length=win_length, win_length=win_length,
window=window, window=window,
length=len_stretch) length=len_stretch)
waveform_shift = resample(waveform_stretch, sample_rate // rate, float(sample_rate)) waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
shift_len = waveform_shift.size()[-1] shift_len = waveform_shift.size()[-1]
if shift_len > ori_len: if shift_len > ori_len:
waveform_shift = waveform_shift[..., :ori_len] waveform_shift = waveform_shift[..., :ori_len]
......
...@@ -815,8 +815,8 @@ class Resample(torch.nn.Module): ...@@ -815,8 +815,8 @@ class Resample(torch.nn.Module):
Alternatively, you could rewrite a transform that caches a higher precision kernel. Alternatively, you could rewrite a transform that caches a higher precision kernel.
Args: Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``) orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``) new_freq (int, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method to use. resampling_method (str, optional): The resampling method to use.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``) Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
...@@ -840,8 +840,8 @@ class Resample(torch.nn.Module): ...@@ -840,8 +840,8 @@ class Resample(torch.nn.Module):
def __init__( def __init__(
self, self,
orig_freq: float = 16000, orig_freq: int = 16000,
new_freq: float = 16000, new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation', resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment