Unverified Commit 7763ed87 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Add F.resample torchscript test (#1516)

parent a21b08e3
...@@ -591,6 +591,28 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -591,6 +591,28 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100) tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
def test_resample_sinc(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")
tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)
def test_resample_kaiser(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window")
def func_beta(tensor):
sr1, sr2 = 16000., 8000.
beta = 6.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)
tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)
self._assert_consistency(func_beta, tensor)
@parameterized.expand([(True, ), (False, )]) @parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex): def test_phase_vocoder(self, test_paseudo_complex):
def func(tensor): def func(tensor):
......
...@@ -1328,9 +1328,6 @@ def _get_sinc_resample_kernel( ...@@ -1328,9 +1328,6 @@ def _get_sinc_resample_kernel(
orig_freq = int(orig_freq) // gcd orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd new_freq = int(new_freq) // gcd
if resampling_method == "kaiser_window" and beta is None:
beta = 14.769656459379492
assert lowpass_filter_width > 0 assert lowpass_filter_width > 0
kernels = [] kernels = []
base_freq = min(orig_freq, new_freq) base_freq = min(orig_freq, new_freq)
...@@ -1373,9 +1370,12 @@ def _get_sinc_resample_kernel( ...@@ -1373,9 +1370,12 @@ def _get_sinc_resample_kernel(
# at specific positions, not over a regular grid. # at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation": if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2 window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
elif resampling_method == "kaiser_window": else:
beta = torch.tensor(beta, dtype=float) # kaiser_window
window = torch.i0(beta * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta) if beta is None:
beta = 14.769656459379492
beta_tensor = torch.tensor(float(beta))
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window) kernel.mul_(window)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment