Commit fc0720b4 authored by moto, committed by Facebook GitHub Bot

Add sinc_impulse_response op (#2875)

Summary:
This commit adds `sinc_impulse_response`, which generates windowed-sinc low-pass filters for given cutoff frequencies.

Example usage:
 - [Filter Design Tutorial](https://output.circle-artifacts.com/output/job/c0085baa-5345-4aeb-bd44-448034caa9e1/artifacts/0/docs/tutorials/filter_design_tutorial.html)
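A minimal usage sketch (assumptions, not part of this commit: the op is exposed as `torchaudio.prototype.functional.sinc_impulse_response`, matching the import changes in this diff, and cutoff values are normalized frequencies where 1.0 corresponds to the Nyquist frequency, as implied by the `sinc(cutoff * idx)` term in the implementation):

```python
import torch
import torchaudio.prototype.functional as F  # assumed module path

# Three cutoff frequencies; 1.0 corresponds to the Nyquist frequency
# (inferred from the sinc argument over integer sample indices).
cutoff = torch.tensor([0.1, 0.25, 0.5])

# Windowed-sinc low-pass filters, one per cutoff. Shape: (3, 513).
lowpass = F.sinc_impulse_response(cutoff, window_size=513)

# High-pass counterparts, obtained by spectral inversion of the low-pass IRs.
highpass = F.sinc_impulse_response(cutoff, window_size=513, high_pass=True)

print(lowpass.shape)  # torch.Size([3, 513])
```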

Pull Request resolved: https://github.com/pytorch/audio/pull/2875

Reviewed By: carolineechen

Differential Revision: D41586631

Pulled By: mthrok

fbshipit-source-id: a9991dbe5b137b0b4679228ec37072a1da7e50bb
parent 19e1a84d
@@ -34,3 +34,4 @@ DSP
adsr_envelope
extend_pitch
oscillator_bank
sinc_impulse_response
@@ -62,3 +62,8 @@ class AutogradTestImpl(TestBaseMixin):
assert gradcheck(F.extend_pitch, (input, num_pitches))
assert gradcheck(F.extend_pitch, (input, pattern))
def test_sinc_ir(self):
cutoff = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype, requires_grad=True)
assert gradcheck(F.sinc_impulse_response, (cutoff, 513, False))
assert gradcheck(F.sinc_impulse_response, (cutoff, 513, True))
@@ -18,3 +18,20 @@ def oscillator_bank(
waveform = amplitudes * np.sin(phases)
return waveform
def sinc_ir(cutoff: ArrayLike, window_size: int = 513, high_pass: bool = False):
if window_size % 2 == 0:
raise ValueError(f"`window_size` must be odd. Given: {window_size}")
half = window_size // 2
dtype = cutoff.dtype
idx = np.linspace(-half, half, window_size, dtype=dtype)
filt = np.sinc(cutoff[..., None] * idx[None, ...])
filt *= np.hamming(window_size).astype(dtype)[None, ...]
filt /= np.abs(filt.sum(axis=-1, keepdims=True))
if high_pass:
filt *= -1
filt[..., half] = 1.0 + filt[..., half]
return filt
@@ -5,7 +5,14 @@ from parameterized import param, parameterized
from scipy import signal
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
from .dsp_utils import oscillator_bank as oscillator_bank_np
from .dsp_utils import oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
def _prod(l):
r = 1
for p in l:
r *= p
return r
class FunctionalTestImpl(TestBaseMixin):
@@ -361,6 +368,49 @@ class FunctionalTestImpl(TestBaseMixin):
output = F.extend_pitch(input, pat)
self.assertEqual(output, expected)
@nested_params(
# fmt: off
[(1,), (10,), (2, 5), (3, 5, 7)],
[1, 3, 65, 129, 257, 513, 1025],
[True, False],
# fmt: on
)
def test_sinc_ir_shape(self, input_shape, window_size, high_pass):
"""The shape of sinc_impulse_response is correct"""
numel = _prod(input_shape)
cutoff = torch.linspace(1, numel, numel).reshape(input_shape)
cutoff = cutoff.to(dtype=self.dtype, device=self.device)
filt = F.sinc_impulse_response(cutoff, window_size, high_pass)
assert filt.shape == input_shape + (window_size,)
@nested_params([True, False])
def test_sinc_ir_size(self, high_pass):
"""Increasing window size expand the filter at the ends. Core parts must stay same"""
cutoff = torch.tensor([200, 300, 400, 500, 600, 700])
cutoff = cutoff.to(dtype=self.dtype, device=self.device)
filt_5 = F.sinc_impulse_response(cutoff, 5, high_pass)
filt_3 = F.sinc_impulse_response(cutoff, 3, high_pass)
self.assertEqual(filt_3, filt_5[..., 1:-1])
@nested_params(
# fmt: off
[0, 0.1, 0.5, 0.9, 1.0],
[1, 3, 5, 65, 129, 257, 513, 1025, 2049],
[False, True],
# fmt: on
)
def test_sinc_ir_reference(self, cutoff, window_size, high_pass):
"""sinc_impulse_response produces the same result as reference implementation"""
cutoff = torch.tensor([cutoff], device=self.device, dtype=self.dtype)
hyp = F.sinc_impulse_response(cutoff, window_size, high_pass)
ref = sinc_ir_np(cutoff.cpu().numpy(), window_size, high_pass)
self.assertEqual(hyp, ref)
class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
@@ -76,3 +76,8 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
self._assert_consistency(F.extend_pitch, (input, num_pitches))
self._assert_consistency(F.extend_pitch, (input, pattern))
self._assert_consistency(F.extend_pitch, (input, torch.tensor(pattern)))
def test_sinc_ir(self):
cutoff = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype)
self._assert_consistency(F.sinc_impulse_response, (cutoff, 513, False))
self._assert_consistency(F.sinc_impulse_response, (cutoff, 513, True))
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank, sinc_impulse_response
from .functional import add_noise, barkscale_fbanks, convolve, fftconvolve
__all__ = [
@@ -9,4 +9,5 @@ __all__ = [
"extend_pitch",
"fftconvolve",
"oscillator_bank",
"sinc_impulse_response",
]
@@ -247,3 +247,42 @@ def extend_pitch(
mult = torch.tensor(pattern, dtype=base.dtype, device=base.device)
h_freq = base @ mult.unsqueeze(0)
return h_freq
def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pass: bool = False):
"""Create windowed-sinc impulse response for given cutoff frequencies.
.. devices:: CPU CUDA
.. properties:: Autograd TorchScript
Args:
cutoff (Tensor): Cutoff frequencies for low-pass sinc filter.
window_size (int, optional): Size of the Hamming window to apply. Must be odd.
(Default: 513)
high_pass (bool, optional):
If ``True``, convert the resulting filter to high-pass.
Otherwise low-pass filter is returned. Default: ``False``.
Returns:
Tensor: A series of impulse responses. Shape: `(..., window_size)`.
"""
if window_size % 2 == 0:
raise ValueError(f"`window_size` must be odd. Given: {window_size}")
half = window_size // 2
device, dtype = cutoff.device, cutoff.dtype
idx = torch.linspace(-half, half, window_size, device=device, dtype=dtype)
filt = torch.special.sinc(cutoff.unsqueeze(-1) * idx.unsqueeze(0))
filt = filt * torch.hamming_window(window_size, device=device, dtype=dtype, periodic=False).unsqueeze(0)
filt = filt / filt.sum(dim=-1, keepdim=True).abs()
# The high-pass IR is obtained by subtracting the low-pass IR from the delta function.
# https://courses.engr.illinois.edu/ece401/fa2020/slides/lec10.pdf
if high_pass:
filt = -filt
filt[..., half] = 1.0 + filt[..., half]
return filt
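For illustration, a quick check of the relationship stated in the comment above (the high-pass IR equals the delta function minus the low-pass IR); the module path and example tensors below are assumptions, not part of this diff:

```python
import torch
import torchaudio.prototype.functional as F  # assumed module path

cutoff = torch.tensor([0.3])
window_size = 513
half = window_size // 2

lp = F.sinc_impulse_response(cutoff, window_size, high_pass=False)
hp = F.sinc_impulse_response(cutoff, window_size, high_pass=True)

# Delta function: 1.0 at the center tap, 0 elsewhere.
delta = torch.zeros(window_size)
delta[half] = 1.0

# Spectral inversion: high-pass == delta - low-pass.
assert torch.allclose(hp, delta - lp)
```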