Unverified Commit 8094751f authored by Chin-Yun Yu's avatar Chin-Yun Yu Committed by GitHub
Browse files

Add batch support to lfilter (#1638)

parent 15bc554f
from typing import Callable, Tuple
from functools import partial
import torch
from parameterized import parameterized
from torch import Tensor
......@@ -62,6 +63,15 @@ class Autograd(TestBaseMixin):
def test_lfilter_filterbanks(self):
    """Gradients flow through lfilter when a bank of filters is applied
    to every channel (``batching=False`` broadcasts filters over channels)."""
    torch.random.manual_seed(2434)
    waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
    # Two second-order filters: rows are [c0, c1, c2] coefficient triples.
    a_coeffs = torch.tensor(
        [[0.7, 0.2, 0.6],
         [0.8, 0.2, 0.9]])
    b_coeffs = torch.tensor(
        [[0.4, 0.2, 0.9],
         [0.7, 0.2, 0.6]])
    self.assert_grad(partial(F.lfilter, batching=False), (waveform, a_coeffs, b_coeffs))
def test_lfilter_batching(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9],
......
......@@ -217,3 +217,18 @@ class TestFunctional(common_utils.TorchaudioTestCase):
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(
F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
def test_lfilter(self):
    """Batched lfilter must match filtering each (signal, filter) pair individually."""
    n_frames = 2048
    torch.manual_seed(2434)
    signals = torch.randn(self.batch_size, n_frames)
    a_coeffs = torch.rand(self.batch_size, 3)
    b_coeffs = torch.rand(self.batch_size, 3)

    batched = F.lfilter(signals, a_coeffs, b_coeffs, batching=True)
    # Reference: run each row through lfilter on its own, then re-stack.
    per_item = torch.stack([
        F.lfilter(sig, a_row, b_row)
        for sig, a_row, b_row in zip(signals, a_coeffs, b_coeffs)
    ])
    self.assertEqual(batched, per_item)
......@@ -80,7 +80,7 @@ class Functional(TestBaseMixin):
waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False)
assert input_shape == waveform.size()
assert target_shape == output_waveform.size()
......
......@@ -930,6 +930,7 @@ def lfilter(
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
batching: bool = True
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.
......@@ -948,6 +949,10 @@ def lfilter(
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
batching (bool, optional): Activate when coefficients are in 2D. If ``True``, then waveform should be at least
2D, and the size of the second-to-last axis should equal ``num_filters``.
The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :],
a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)
Returns:
Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
......@@ -957,7 +962,11 @@ def lfilter(
assert a_coeffs.ndim <= 2
if a_coeffs.ndim > 1:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
if batching:
assert waveform.ndim > 1
assert waveform.shape[-2] == a_coeffs.shape[0]
else:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
else:
a_coeffs = a_coeffs.unsqueeze(0)
b_coeffs = b_coeffs.unsqueeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment