Commit 5c6e602c authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Speed up resample with kernel generation modification (#2553)

Summary:
Modification from pull request https://github.com/pytorch/audio/issues/2415 to improve resample.

Benchmarked at an 89% time reduction, tested against the original resample method.

Pull Request resolved: https://github.com/pytorch/audio/pull/2553

Reviewed By: carolineechen

Differential Revision: D37997533

Pulled By: skim0514

fbshipit-source-id: ef4b719450ac26794db6ea01f9882509f4fda5cf
parent a2d6fee2
...@@ -1414,7 +1414,6 @@ def _get_sinc_resample_kernel( ...@@ -1414,7 +1414,6 @@ def _get_sinc_resample_kernel(
new_freq = int(new_freq) // gcd new_freq = int(new_freq) // gcd
assert lowpass_filter_width > 0 assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq) base_freq = min(orig_freq, new_freq)
# This will perform antialiasing filtering by removing the highest frequencies. # This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling # At first I thought I only needed this when downsampling, but when upsampling
...@@ -1445,31 +1444,33 @@ def _get_sinc_resample_kernel( ...@@ -1445,31 +1444,33 @@ def _get_sinc_resample_kernel(
# There is probably a way to evaluate those filters more efficiently, but this is kept for # There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work. # future work.
idx_dtype = dtype if dtype is not None else torch.float64 idx_dtype = dtype if dtype is not None else torch.float64
idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
for i in range(new_freq): idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[None, None] / orig_freq
t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
# we do not use built in torch windows here as we need to evaluate the window t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
# at specific positions, not over a regular grid. t *= base_freq
if resampling_method == "sinc_interpolation": t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
else: # we do not use built in torch windows here as we need to evaluate the window
# kaiser_window # at specific positions, not over a regular grid.
if beta is None: if resampling_method == "sinc_interpolation":
beta = 14.769656459379492 window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
beta_tensor = torch.tensor(float(beta)) else:
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor) # kaiser_window
t *= math.pi if beta is None:
kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t) beta = 14.769656459379492
kernel.mul_(window) beta_tensor = torch.tensor(float(beta))
kernels.append(kernel) window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi
scale = base_freq / orig_freq scale = base_freq / orig_freq
kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale) kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
kernels *= window * scale
if dtype is None: if dtype is None:
kernels = kernels.to(dtype=torch.float32) kernels = kernels.to(dtype=torch.float32)
return kernels, width return kernels, width
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment