Unverified Commit 496b381a authored by hwangjeff's avatar hwangjeff Committed by GitHub
Browse files

Add basic filtfilt implementation (#1681)



* Add basic filtfilt implementation

* Add filtfilt to functional package; add tests
Co-authored-by: default avatarV G <vladislav.goncharenko@phystech.edu>
parent d1ce29a0
...@@ -86,7 +86,6 @@ complex_norm ...@@ -86,7 +86,6 @@ complex_norm
.. autofunction:: complex_norm .. autofunction:: complex_norm
magphase magphase
-------- --------
...@@ -152,6 +151,11 @@ equalizer_biquad ...@@ -152,6 +151,11 @@ equalizer_biquad
.. autofunction:: equalizer_biquad .. autofunction:: equalizer_biquad
filtfilt
--------
.. autofunction:: filtfilt
flanger flanger
------- -------
......
...@@ -79,6 +79,38 @@ class Autograd(TestBaseMixin): ...@@ -79,6 +79,38 @@ class Autograd(TestBaseMixin):
[0.7, 0.2, 0.6]]) [0.7, 0.2, 0.6]])
self.assert_grad(F.lfilter, (x, a, b)) self.assert_grad(F.lfilter, (x, a, b))
def test_filtfilt_a(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
a.requires_grad = True
self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
def test_filtfilt_b(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
b.requires_grad = True
self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
def test_filtfilt_all_inputs(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.filtfilt, (x, a, b))
def test_filtfilt_batching(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(F.filtfilt, (x, a, b))
def test_biquad(self): def test_biquad(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
......
...@@ -232,3 +232,18 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -232,3 +232,18 @@ class TestFunctional(common_utils.TorchaudioTestCase):
]) ])
self.assertEqual(batchwise_output, itemwise_output) self.assertEqual(batchwise_output, itemwise_output)
def test_filtfilt(self):
signal_length = 2048
torch.manual_seed(2434)
x = torch.randn(self.batch_size, signal_length)
a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3)
batchwise_output = F.filtfilt(x, a, b)
itemwise_output = torch.stack([
F.filtfilt(x[i], a[i], b[i])
for i in range(self.batch_size)
])
self.assertEqual(batchwise_output, itemwise_output)
...@@ -121,6 +121,93 @@ class Functional(TestBaseMixin): ...@@ -121,6 +121,93 @@ class Functional(TestBaseMixin):
yhat = F.lfilter(x, a, b, False) yhat = F.lfilter(x, a, b, False)
self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5) self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5)
def test_filtfilt_simple(self):
"""
Check that, for an arbitrary signal, applying filtfilt with filter coefficients
corresponding to a pure delay filter imparts no time delay.
"""
waveform = get_whitenoise(sample_rate=8000, n_channels=2, dtype=self.dtype).to(
device=self.device
)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
padded_waveform = torch.cat(
(waveform, torch.zeros(2, 3, dtype=self.dtype, device=self.device)), axis=1
)
output_waveform = F.filtfilt(padded_waveform, a_coeffs, b_coeffs)
self.assertEqual(output_waveform, padded_waveform, atol=1e-5, rtol=1e-5)
def test_filtfilt_filter_sinusoid(self):
"""
Check that, for a signal comprising two sinusoids, applying filtfilt
with appropriate filter coefficients correctly removes the higher-frequency
sinusoid while imparting no time delay.
"""
T = 1.0
samples = 1000
waveform_k0 = get_sinusoid(
frequency=5, sample_rate=samples // T, dtype=self.dtype, device=self.device
).squeeze(0)
waveform_k1 = get_sinusoid(
frequency=200,
sample_rate=samples // T,
dtype=self.dtype,
device=self.device,
).squeeze(0)
waveform = waveform_k0 + waveform_k1
# Transfer function numerator and denominator polynomial coefficients
# corresponding to 8th-order Butterworth filter with 100-cycle/T cutoff.
# Generated with
# >>> from scipy import signal
# >>> b_coeffs, a_coeffs = signal.butter(8, 0.2)
b_coeffs = torch.tensor(
[
2.39596441e-05,
1.91677153e-04,
6.70870035e-04,
1.34174007e-03,
1.67717509e-03,
1.34174007e-03,
6.70870035e-04,
1.91677153e-04,
2.39596441e-05,
],
dtype=self.dtype,
device=self.device,
)
a_coeffs = torch.tensor(
[
1.0,
-4.78451489,
10.44504107,
-13.45771989,
11.12933104,
-6.0252604,
2.0792738,
-0.41721716,
0.0372001,
],
dtype=self.dtype,
device=self.device,
)
# Extend waveform in each direction, preserving periodicity.
padded_waveform = torch.cat((waveform[:-1], waveform, waveform[1:]))
output_waveform = F.filtfilt(padded_waveform, a_coeffs, b_coeffs)
# Remove padding from output waveform; confirm that result
# closely matches waveform_k0.
self.assertEqual(
output_waveform[samples - 1: 2 * samples - 1],
waveform_k0,
atol=1e-3,
rtol=1e-3,
)
@parameterized.expand([(0., ), (1., ), (2., ), (3., )]) @parameterized.expand([(0., ), (1., ), (2., ), (3., )])
def test_spectogram_grad_at_zero(self, power): def test_spectogram_grad_at_zero(self, power):
"""The gradient of power spectrogram should not be nan but zero near x=0 """The gradient of power spectrogram should not be nan but zero near x=0
......
...@@ -334,6 +334,16 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -334,6 +334,16 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_filtfilt(self):
def func(tensor):
torch.manual_seed(296)
b_coeffs = torch.rand(4, device=tensor.device, dtype=tensor.dtype)
a_coeffs = torch.rand(4, device=tensor.device, dtype=tensor.dtype)
return F.filtfilt(tensor, a_coeffs, b_coeffs)
waveform = common_utils.get_whitenoise(sample_rate=8000)
self._assert_consistency(func, waveform)
def test_lowpass(self): def test_lowpass(self):
if self.dtype == torch.float64: if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64") raise unittest.SkipTest("This test is known to fail for float64")
......
...@@ -39,6 +39,7 @@ from .filtering import ( ...@@ -39,6 +39,7 @@ from .filtering import (
dcshift, dcshift,
deemph_biquad, deemph_biquad,
equalizer_biquad, equalizer_biquad,
filtfilt,
flanger, flanger,
gain, gain,
highpass_biquad, highpass_biquad,
...@@ -85,6 +86,7 @@ __all__ = [ ...@@ -85,6 +86,7 @@ __all__ = [
'dcshift', 'dcshift',
'deemph_biquad', 'deemph_biquad',
'equalizer_biquad', 'equalizer_biquad',
'filtfilt',
'flanger', 'flanger',
'gain', 'gain',
'highpass_biquad', 'highpass_biquad',
......
...@@ -642,6 +642,36 @@ def equalizer_biquad( ...@@ -642,6 +642,36 @@ def equalizer_biquad(
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def filtfilt(
waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True,
) -> Tensor:
r"""Apply an IIR filter forward and backward to a waveform.
Inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html
Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns:
Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or ``(..., time)`` otherwise.
"""
forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
backward_filtered = lfilter(
forward_filtered.flip(-1), a_coeffs, b_coeffs, clamp=clamp, batching=True,
).flip(-1)
return backward_filtered
def flanger( def flanger(
waveform: Tensor, waveform: Tensor,
sample_rate: int, sample_rate: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment