Unverified commit aa0dd03b, authored by Chin-Yun Yu and committed by GitHub
Browse files

Add filterbanks support to lfilter (#1587)

parent e7b43dde
......@@ -59,6 +59,15 @@ class Autograd(TestBaseMixin):
b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.lfilter, (x, a, b))
# Gradient test for ``F.lfilter`` when the coefficients are 2D (a filterbank):
# ``a`` and ``b`` each hold two filters with three coefficients per filter.
def test_lfilter_filterbanks(self):
# Fixed seed so the generated input is reproducible between runs.
torch.random.manual_seed(2434)
# Short white-noise input requested with three channels.
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
# Denominator coefficients, one row per filter.
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
# Numerator coefficients, same shape as ``a``.
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
# NOTE(review): assert_grad is defined outside this view; presumably it
# checks gradients of F.lfilter with respect to (x, a, b) — confirm.
self.assert_grad(F.lfilter, (x, a, b))
def test_biquad(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
......
......@@ -67,18 +67,22 @@ class Functional(TestBaseMixin):
assert output_signal.max() > 1
@parameterized.expand([
((44100,),),
((3, 44100),),
((2, 3, 44100),),
((1, 2, 3, 44100),)
((44100,), (4,), (44100,)),
((3, 44100), (4,), (3, 44100,)),
((2, 3, 44100), (4,), (2, 3, 44100,)),
((1, 2, 3, 44100), (4,), (1, 2, 3, 44100,)),
((44100,), (2, 4), (2, 44100)),
((3, 44100), (1, 4), (3, 1, 44100)),
((1, 2, 44100), (3, 4), (1, 2, 3, 44100))
])
def test_lfilter_shape(self, shape):
def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
torch.random.manual_seed(42)
waveform = torch.rand(*shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert shape == waveform.size() == output_waveform.size()
assert input_shape == waveform.size()
assert target_shape == output_waveform.size()
def test_lfilter_9th_order_filter_stability(self):
"""
......
......@@ -8,23 +8,28 @@ void host_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_channel = input_signal_windows.size(0);
int64_t n_samples_input = input_signal_windows.size(1);
int64_t n_samples_output = padded_output_waveform.size(1);
int64_t n_order = a_coeff_flipped.size(0);
int64_t n_batch = input_signal_windows.size(0);
int64_t n_channel = input_signal_windows.size(1);
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_samples_output = padded_output_waveform.size(2);
int64_t n_order = a_coeff_flipped.size(1);
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
int64_t offset_input = i_channel * n_samples_input;
int64_t offset_output = i_channel * n_samples_output;
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff];
for (int64_t i_batch = 0; i_batch < n_batch; i_batch++) {
for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
int64_t offset_input =
((i_batch * n_channel) + i_channel) * n_samples_input;
int64_t offset_output =
((i_batch * n_channel) + i_channel) * n_samples_output;
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff + i_channel * n_order];
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
}
}
......@@ -51,10 +56,11 @@ void cpu_lfilter_core_loop(
padded_output_waveform.dtype() == torch::kFloat64));
TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
TORCH_CHECK(
input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
padded_output_waveform.size(1));
input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
padded_output_waveform.size(2));
AT_DISPATCH_FLOATING_TYPES(
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
......@@ -67,16 +73,26 @@ void lfilter_core_generic_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_samples_input = input_signal_windows.size(1);
int64_t n_order = a_coeff_flipped.size(0);
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_order = a_coeff_flipped.size(1);
auto coeff = a_coeff_flipped.unsqueeze(2);
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
auto windowed_output_signal = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(i_sample, i_sample + n_order)});
auto o0 = input_signal_windows.index({torch::indexing::Slice(), i_sample})
.addmv(windowed_output_signal, a_coeff_flipped, 1, -1);
auto windowed_output_signal =
padded_output_waveform
.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(i_sample, i_sample + n_order)})
.transpose(0, 1);
auto o0 =
input_signal_windows.index(
{torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) -
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
padded_output_waveform.index_put_(
{torch::indexing::Slice(), i_sample + n_order - 1}, o0);
{torch::indexing::Slice(),
torch::indexing::Slice(),
i_sample + n_order - 1},
o0);
}
}
......@@ -88,16 +104,17 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_channel = waveform.size(0);
int64_t n_sample = waveform.size(1);
int64_t n_order = a_coeffs_normalized.size(0);
int64_t n_batch = waveform.size(0);
int64_t n_channel = waveform.size(1);
int64_t n_sample = waveform.size(2);
int64_t n_order = a_coeffs_normalized.size(1);
int64_t n_sample_padded = n_sample + n_order - 1;
auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();
auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();
auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform =
torch::zeros({n_channel, n_sample_padded}, options);
torch::zeros({n_batch, n_channel, n_sample_padded}, options);
if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
......@@ -108,6 +125,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});
ctx->save_for_backward({waveform, a_coeffs_normalized, output});
......@@ -122,8 +140,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
auto a_coeffs_normalized = saved[1];
auto y = saved[2];
int64_t n_channel = x.size(0);
int64_t n_order = a_coeffs_normalized.size(0);
int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = a_coeffs_normalized.size(1);
auto dx = torch::Tensor();
auto da = torch::Tensor();
......@@ -137,16 +156,16 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
F::PadFuncOptions({n_order - 1, 0}));
da = F::conv1d(
dyda.unsqueeze(0),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
dyda.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}
if (x.requires_grad()) {
dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2);
}
return {dx, da};
......@@ -159,17 +178,18 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(0);
int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);
namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
auto output =
F::conv1d(
padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
.squeeze(1);
auto output = F::conv1d(
padded_waveform,
b_coeff_flipped.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
ctx->save_for_backward({waveform, b_coeffs, output});
return output;
......@@ -183,8 +203,9 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
auto b_coeffs = saved[1];
auto y = saved[2];
int64_t n_channel = x.size(0);
int64_t n_order = b_coeffs.size(0);
int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = b_coeffs.size(1);
auto dx = torch::Tensor();
auto db = torch::Tensor();
......@@ -194,19 +215,20 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}
if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
b_coeffs.view({1, 1, n_order}))
.squeeze(1);
F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
}
return {dx, db};
......@@ -219,19 +241,27 @@ torch::Tensor lfilter_core(
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());
TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));
int64_t n_order = b_coeffs.size(0);
int64_t n_order = b_coeffs.size(1);
TORCH_INTERNAL_ASSERT(n_order > 0);
auto filtered_waveform =
DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);
auto output =
DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
auto filtered_waveform = DifferentiableFIR::apply(
waveform,
b_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
auto output = DifferentiableIIR::apply(
filtered_waveform,
a_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
return output;
}
......
......@@ -855,13 +855,14 @@ def highpass_biquad(
def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor):
n_order = a_coeffs_flipped.size(0)
for i_sample, o0 in enumerate(input_signal_windows.t()):
n_order = a_coeffs_flipped.size(1)
a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2)
for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)):
windowed_output_signal = padded_output_waveform[
:, i_sample:i_sample + n_order
:, :, i_sample:i_sample + n_order
]
o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
padded_output_waveform[:, i_sample + n_order - 1] = o0
o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t()
padded_output_waveform[:, :, i_sample + n_order - 1] = o0
try:
......@@ -877,13 +878,13 @@ def _lfilter_core(
b_coeffs: Tensor,
) -> Tensor:
assert a_coeffs.size(0) == b_coeffs.size(0)
assert len(waveform.size()) == 2
assert a_coeffs.size() == b_coeffs.size()
assert len(waveform.size()) == 3
assert waveform.device == a_coeffs.device
assert b_coeffs.device == a_coeffs.device
n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(0)
n_batch, n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(1)
assert n_order > 0
# Pad the input and create output
......@@ -893,17 +894,18 @@ def _lfilter_core(
# Set up the coefficients matrix
# Flip coefficients' order
a_coeffs_flipped = a_coeffs.flip(0)
b_coeffs_flipped = b_coeffs.flip(0)
a_coeffs_flipped = a_coeffs.flip(1)
b_coeffs_flipped = b_coeffs.flip(1)
# calculate windowed_input_signal in parallel using convolution
input_signal_windows = torch.nn.functional.conv1d(
padded_waveform.unsqueeze(1),
b_coeffs_flipped.view(1, 1, -1)
).squeeze(1)
padded_waveform,
b_coeffs_flipped.unsqueeze(1),
groups=n_channel
)
input_signal_windows.div_(a_coeffs[0])
a_coeffs_flipped.div_(a_coeffs[0])
input_signal_windows.div_(a_coeffs[:, :1])
a_coeffs_flipped.div_(a_coeffs[:, :1])
if input_signal_windows.device == torch.device('cpu') and\
a_coeffs_flipped.device == torch.device('cpu') and\
......@@ -912,9 +914,10 @@ def _lfilter_core(
else:
_lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
output = padded_output_waveform[:, n_order - 1:]
output = padded_output_waveform[:, :, n_order - 1:]
return output
try:
_lfilter = torch.ops.torchaudio._lfilter
except RuntimeError as err:
......@@ -936,21 +939,32 @@ def lfilter(
Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
a_coeffs (Tensor): denominator coefficients of difference equation, given as either a 1D Tensor
of shape ``(num_order + 1)`` or a 2D Tensor of shape ``(num_filters, num_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation, given as either a 1D Tensor
of shape ``(num_order + 1)`` or a 2D Tensor of shape ``(num_filters, num_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns:
Tensor: Waveform with dimension of ``(..., time)``.
Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or ``(..., time)`` otherwise.
"""
assert a_coeffs.size() == b_coeffs.size()
assert a_coeffs.ndim <= 2
if a_coeffs.ndim > 1:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
else:
a_coeffs = a_coeffs.unsqueeze(0)
b_coeffs = b_coeffs.unsqueeze(0)
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
waveform = waveform.reshape(-1, a_coeffs.shape[0], shape[-1])
output = _lfilter(waveform, a_coeffs, b_coeffs)
if clamp:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment