Unverified Commit 8094751f authored by Chin-Yun Yu's avatar Chin-Yun Yu Committed by GitHub
Browse files

Add batch support to lfilter (#1638)

parent 15bc554f
from typing import Callable, Tuple from typing import Callable, Tuple
from functools import partial
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torch import Tensor from torch import Tensor
...@@ -62,6 +63,15 @@ class Autograd(TestBaseMixin): ...@@ -62,6 +63,15 @@ class Autograd(TestBaseMixin):
def test_lfilter_filterbanks(self): def test_lfilter_filterbanks(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))
def test_lfilter_batching(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([[0.7, 0.2, 0.6], a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]]) [0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9], b = torch.tensor([[0.4, 0.2, 0.9],
......
...@@ -217,3 +217,18 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -217,3 +217,18 @@ class TestFunctional(common_utils.TorchaudioTestCase):
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency( self.assert_batch_consistency(
F.compute_kaldi_pitch, batch, sample_rate=sample_rate) F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
def test_lfilter(self):
signal_length = 2048
torch.manual_seed(2434)
x = torch.randn(self.batch_size, signal_length)
a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3)
batchwise_output = F.lfilter(x, a, b, batching=True)
itemwise_output = torch.stack([
F.lfilter(x[i], a[i], b[i])
for i in range(self.batch_size)
])
self.assertEqual(batchwise_output, itemwise_output)
...@@ -80,7 +80,7 @@ class Functional(TestBaseMixin): ...@@ -80,7 +80,7 @@ class Functional(TestBaseMixin):
waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device) waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False)
assert input_shape == waveform.size() assert input_shape == waveform.size()
assert target_shape == output_waveform.size() assert target_shape == output_waveform.size()
......
...@@ -930,6 +930,7 @@ def lfilter( ...@@ -930,6 +930,7 @@ def lfilter(
a_coeffs: Tensor, a_coeffs: Tensor,
b_coeffs: Tensor, b_coeffs: Tensor,
clamp: bool = True, clamp: bool = True,
batching: bool = True
) -> Tensor: ) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation. r"""Perform an IIR filter by evaluating difference equation.
...@@ -948,6 +949,10 @@ def lfilter( ...@@ -948,6 +949,10 @@ def lfilter(
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary). Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
batching (bool, optional): Activate when coefficients are in 2D. If ``True``, then waveform should be at least
2D, and the size of second axis from last should equals to ``num_filters``.
The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :],
a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)
Returns: Returns:
Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs`` Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
...@@ -957,7 +962,11 @@ def lfilter( ...@@ -957,7 +962,11 @@ def lfilter(
assert a_coeffs.ndim <= 2 assert a_coeffs.ndim <= 2
if a_coeffs.ndim > 1: if a_coeffs.ndim > 1:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2) if batching:
assert waveform.ndim > 1
assert waveform.shape[-2] == a_coeffs.shape[0]
else:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
else: else:
a_coeffs = a_coeffs.unsqueeze(0) a_coeffs = a_coeffs.unsqueeze(0)
b_coeffs = b_coeffs.unsqueeze(0) b_coeffs = b_coeffs.unsqueeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment