Unverified Commit aa0dd03b authored by Chin-Yun Yu's avatar Chin-Yun Yu Committed by GitHub
Browse files

Add filterbanks support to lfilter (#1587)

parent e7b43dde
...@@ -59,6 +59,15 @@ class Autograd(TestBaseMixin): ...@@ -59,6 +59,15 @@ class Autograd(TestBaseMixin):
b = torch.tensor([0.4, 0.2, 0.9]) b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.lfilter, (x, a, b)) self.assert_grad(F.lfilter, (x, a, b))
def test_lfilter_filterbanks(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(F.lfilter, (x, a, b))
def test_biquad(self): def test_biquad(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
......
...@@ -67,18 +67,22 @@ class Functional(TestBaseMixin): ...@@ -67,18 +67,22 @@ class Functional(TestBaseMixin):
assert output_signal.max() > 1 assert output_signal.max() > 1
@parameterized.expand([ @parameterized.expand([
((44100,),), ((44100,), (4,), (44100,)),
((3, 44100),), ((3, 44100), (4,), (3, 44100,)),
((2, 3, 44100),), ((2, 3, 44100), (4,), (2, 3, 44100,)),
((1, 2, 3, 44100),) ((1, 2, 3, 44100), (4,), (1, 2, 3, 44100,)),
((44100,), (2, 4), (2, 44100)),
((3, 44100), (1, 4), (3, 1, 44100)),
((1, 2, 44100), (3, 4), (1, 2, 3, 44100))
]) ])
def test_lfilter_shape(self, shape): def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
torch.random.manual_seed(42) torch.random.manual_seed(42)
waveform = torch.rand(*shape, dtype=self.dtype, device=self.device) waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device) b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device) a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert shape == waveform.size() == output_waveform.size() assert input_shape == waveform.size()
assert target_shape == output_waveform.size()
def test_lfilter_9th_order_filter_stability(self): def test_lfilter_9th_order_filter_stability(self):
""" """
......
...@@ -8,25 +8,30 @@ void host_lfilter_core_loop( ...@@ -8,25 +8,30 @@ void host_lfilter_core_loop(
const torch::Tensor& input_signal_windows, const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped, const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) { torch::Tensor& padded_output_waveform) {
int64_t n_channel = input_signal_windows.size(0); int64_t n_batch = input_signal_windows.size(0);
int64_t n_samples_input = input_signal_windows.size(1); int64_t n_channel = input_signal_windows.size(1);
int64_t n_samples_output = padded_output_waveform.size(1); int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_order = a_coeff_flipped.size(0); int64_t n_samples_output = padded_output_waveform.size(2);
int64_t n_order = a_coeff_flipped.size(1);
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>(); scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>(); const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>(); const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
for (int64_t i_batch = 0; i_batch < n_batch; i_batch++) {
for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) { for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
int64_t offset_input = i_channel * n_samples_input; int64_t offset_input =
int64_t offset_output = i_channel * n_samples_output; ((i_batch * n_channel) + i_channel) * n_samples_input;
int64_t offset_output =
((i_batch * n_channel) + i_channel) * n_samples_output;
scalar_t a0 = input_data[offset_input + i_sample]; scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] * a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff]; a_coeff_flipped_data[i_coeff + i_channel * n_order];
} }
output_data[offset_output + i_sample + n_order - 1] = a0; output_data[offset_output + i_sample + n_order - 1] = a0;
} }
} }
}
} }
void cpu_lfilter_core_loop( void cpu_lfilter_core_loop(
...@@ -51,10 +56,11 @@ void cpu_lfilter_core_loop( ...@@ -51,10 +56,11 @@ void cpu_lfilter_core_loop(
padded_output_waveform.dtype() == torch::kFloat64)); padded_output_waveform.dtype() == torch::kFloat64));
TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0)); TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
TORCH_CHECK( TORCH_CHECK(
input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 == input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
padded_output_waveform.size(1)); padded_output_waveform.size(2));
AT_DISPATCH_FLOATING_TYPES( AT_DISPATCH_FLOATING_TYPES(
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] { input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
...@@ -67,16 +73,26 @@ void lfilter_core_generic_loop( ...@@ -67,16 +73,26 @@ void lfilter_core_generic_loop(
const torch::Tensor& input_signal_windows, const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped, const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) { torch::Tensor& padded_output_waveform) {
int64_t n_samples_input = input_signal_windows.size(1); int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_order = a_coeff_flipped.size(0); int64_t n_order = a_coeff_flipped.size(1);
auto coeff = a_coeff_flipped.unsqueeze(2);
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
auto windowed_output_signal = padded_output_waveform.index( auto windowed_output_signal =
padded_output_waveform
.index(
{torch::indexing::Slice(), {torch::indexing::Slice(),
torch::indexing::Slice(i_sample, i_sample + n_order)}); torch::indexing::Slice(),
auto o0 = input_signal_windows.index({torch::indexing::Slice(), i_sample}) torch::indexing::Slice(i_sample, i_sample + n_order)})
.addmv(windowed_output_signal, a_coeff_flipped, 1, -1); .transpose(0, 1);
auto o0 =
input_signal_windows.index(
{torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) -
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
padded_output_waveform.index_put_( padded_output_waveform.index_put_(
{torch::indexing::Slice(), i_sample + n_order - 1}, o0); {torch::indexing::Slice(),
torch::indexing::Slice(),
i_sample + n_order - 1},
o0);
} }
} }
...@@ -88,16 +104,17 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> { ...@@ -88,16 +104,17 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
const torch::Tensor& a_coeffs_normalized) { const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device(); auto device = waveform.device();
auto dtype = waveform.dtype(); auto dtype = waveform.dtype();
int64_t n_channel = waveform.size(0); int64_t n_batch = waveform.size(0);
int64_t n_sample = waveform.size(1); int64_t n_channel = waveform.size(1);
int64_t n_order = a_coeffs_normalized.size(0); int64_t n_sample = waveform.size(2);
int64_t n_order = a_coeffs_normalized.size(1);
int64_t n_sample_padded = n_sample + n_order - 1; int64_t n_sample_padded = n_sample + n_order - 1;
auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous(); auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();
auto options = torch::TensorOptions().dtype(dtype).device(device); auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform = auto padded_output_waveform =
torch::zeros({n_channel, n_sample_padded}, options); torch::zeros({n_batch, n_channel, n_sample_padded}, options);
if (device.is_cpu()) { if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform); cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
...@@ -108,6 +125,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> { ...@@ -108,6 +125,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
auto output = padded_output_waveform.index( auto output = padded_output_waveform.index(
{torch::indexing::Slice(), {torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)}); torch::indexing::Slice(n_order - 1, torch::indexing::None)});
ctx->save_for_backward({waveform, a_coeffs_normalized, output}); ctx->save_for_backward({waveform, a_coeffs_normalized, output});
...@@ -122,8 +140,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> { ...@@ -122,8 +140,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
auto a_coeffs_normalized = saved[1]; auto a_coeffs_normalized = saved[1];
auto y = saved[2]; auto y = saved[2];
int64_t n_channel = x.size(0); int64_t n_batch = x.size(0);
int64_t n_order = a_coeffs_normalized.size(0); int64_t n_channel = x.size(1);
int64_t n_order = a_coeffs_normalized.size(1);
auto dx = torch::Tensor(); auto dx = torch::Tensor();
auto da = torch::Tensor(); auto da = torch::Tensor();
...@@ -137,16 +156,16 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> { ...@@ -137,16 +156,16 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
F::PadFuncOptions({n_order - 1, 0})); F::PadFuncOptions({n_order - 1, 0}));
da = F::conv1d( da = F::conv1d(
dyda.unsqueeze(0), dyda.view({1, n_batch * n_channel, -1}),
dy.unsqueeze(1), dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_channel)) F::Conv1dFuncOptions().groups(n_batch * n_channel))
.sum(1) .view({n_batch, n_channel, -1})
.squeeze(0) .sum(0)
.flip(0); .flip(1);
} }
if (x.requires_grad()) { if (x.requires_grad()) {
dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1); dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2);
} }
return {dx, da}; return {dx, da};
...@@ -159,17 +178,18 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> { ...@@ -159,17 +178,18 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
torch::autograd::AutogradContext* ctx, torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform, const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) { const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(0); int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);
namespace F = torch::nn::functional; namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(0).contiguous(); auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform = auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
auto output = auto output = F::conv1d(
F::conv1d( padded_waveform,
padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order})) b_coeff_flipped.unsqueeze(1),
.squeeze(1); F::Conv1dFuncOptions().groups(n_channel));
ctx->save_for_backward({waveform, b_coeffs, output}); ctx->save_for_backward({waveform, b_coeffs, output});
return output; return output;
...@@ -183,8 +203,9 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> { ...@@ -183,8 +203,9 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
auto b_coeffs = saved[1]; auto b_coeffs = saved[1];
auto y = saved[2]; auto y = saved[2];
int64_t n_channel = x.size(0); int64_t n_batch = x.size(0);
int64_t n_order = b_coeffs.size(0); int64_t n_channel = x.size(1);
int64_t n_order = b_coeffs.size(1);
auto dx = torch::Tensor(); auto dx = torch::Tensor();
auto db = torch::Tensor(); auto db = torch::Tensor();
...@@ -194,19 +215,20 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> { ...@@ -194,19 +215,20 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
if (b_coeffs.requires_grad()) { if (b_coeffs.requires_grad()) {
db = F::conv1d( db = F::conv1d(
F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})), F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
dy.unsqueeze(1), .view({1, n_batch * n_channel, -1}),
F::Conv1dFuncOptions().groups(n_channel)) dy.view({n_batch * n_channel, 1, -1}),
.sum(1) F::Conv1dFuncOptions().groups(n_batch * n_channel))
.squeeze(0) .view({n_batch, n_channel, -1})
.flip(0); .sum(0)
.flip(1);
} }
if (x.requires_grad()) { if (x.requires_grad()) {
dx = F::conv1d( dx = F::conv1d(
F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})), F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.view({1, 1, n_order})) b_coeffs.unsqueeze(1),
.squeeze(1); F::Conv1dFuncOptions().groups(n_channel));
} }
return {dx, db}; return {dx, db};
...@@ -219,19 +241,27 @@ torch::Tensor lfilter_core( ...@@ -219,19 +241,27 @@ torch::Tensor lfilter_core(
const torch::Tensor& b_coeffs) { const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device()); TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0)); TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());
TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2); TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));
int64_t n_order = b_coeffs.size(0); int64_t n_order = b_coeffs.size(1);
TORCH_INTERNAL_ASSERT(n_order > 0); TORCH_INTERNAL_ASSERT(n_order > 0);
auto filtered_waveform = auto filtered_waveform = DifferentiableFIR::apply(
DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]); waveform,
b_coeffs /
auto output = a_coeffs.index(
DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]); {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
auto output = DifferentiableIIR::apply(
filtered_waveform,
a_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
return output; return output;
} }
......
...@@ -855,13 +855,14 @@ def highpass_biquad( ...@@ -855,13 +855,14 @@ def highpass_biquad(
def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor): def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor):
n_order = a_coeffs_flipped.size(0) n_order = a_coeffs_flipped.size(1)
for i_sample, o0 in enumerate(input_signal_windows.t()): a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2)
for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)):
windowed_output_signal = padded_output_waveform[ windowed_output_signal = padded_output_waveform[
:, i_sample:i_sample + n_order :, :, i_sample:i_sample + n_order
] ]
o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1) o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t()
padded_output_waveform[:, i_sample + n_order - 1] = o0 padded_output_waveform[:, :, i_sample + n_order - 1] = o0
try: try:
...@@ -877,13 +878,13 @@ def _lfilter_core( ...@@ -877,13 +878,13 @@ def _lfilter_core(
b_coeffs: Tensor, b_coeffs: Tensor,
) -> Tensor: ) -> Tensor:
assert a_coeffs.size(0) == b_coeffs.size(0) assert a_coeffs.size() == b_coeffs.size()
assert len(waveform.size()) == 2 assert len(waveform.size()) == 3
assert waveform.device == a_coeffs.device assert waveform.device == a_coeffs.device
assert b_coeffs.device == a_coeffs.device assert b_coeffs.device == a_coeffs.device
n_channel, n_sample = waveform.size() n_batch, n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(0) n_order = a_coeffs.size(1)
assert n_order > 0 assert n_order > 0
# Pad the input and create output # Pad the input and create output
...@@ -893,17 +894,18 @@ def _lfilter_core( ...@@ -893,17 +894,18 @@ def _lfilter_core(
# Set up the coefficients matrix # Set up the coefficients matrix
# Flip coefficients' order # Flip coefficients' order
a_coeffs_flipped = a_coeffs.flip(0) a_coeffs_flipped = a_coeffs.flip(1)
b_coeffs_flipped = b_coeffs.flip(0) b_coeffs_flipped = b_coeffs.flip(1)
# calculate windowed_input_signal in parallel using convolution # calculate windowed_input_signal in parallel using convolution
input_signal_windows = torch.nn.functional.conv1d( input_signal_windows = torch.nn.functional.conv1d(
padded_waveform.unsqueeze(1), padded_waveform,
b_coeffs_flipped.view(1, 1, -1) b_coeffs_flipped.unsqueeze(1),
).squeeze(1) groups=n_channel
)
input_signal_windows.div_(a_coeffs[0]) input_signal_windows.div_(a_coeffs[:, :1])
a_coeffs_flipped.div_(a_coeffs[0]) a_coeffs_flipped.div_(a_coeffs[:, :1])
if input_signal_windows.device == torch.device('cpu') and\ if input_signal_windows.device == torch.device('cpu') and\
a_coeffs_flipped.device == torch.device('cpu') and\ a_coeffs_flipped.device == torch.device('cpu') and\
...@@ -912,9 +914,10 @@ def _lfilter_core( ...@@ -912,9 +914,10 @@ def _lfilter_core(
else: else:
_lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
output = padded_output_waveform[:, n_order - 1:] output = padded_output_waveform[:, :, n_order - 1:]
return output return output
try: try:
_lfilter = torch.ops.torchaudio._lfilter _lfilter = torch.ops.torchaudio._lfilter
except RuntimeError as err: except RuntimeError as err:
...@@ -936,21 +939,32 @@ def lfilter( ...@@ -936,21 +939,32 @@ def lfilter(
Args: Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``. a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary). Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``. b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary). Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns: Returns:
Tensor: Waveform with dimension of ``(..., time)``. Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or ``(..., time)`` otherwise.
""" """
assert a_coeffs.size() == b_coeffs.size()
assert a_coeffs.ndim <= 2
if a_coeffs.ndim > 1:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
else:
a_coeffs = a_coeffs.unsqueeze(0)
b_coeffs = b_coeffs.unsqueeze(0)
# pack batch # pack batch
shape = waveform.size() shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1]) waveform = waveform.reshape(-1, a_coeffs.shape[0], shape[-1])
output = _lfilter(waveform, a_coeffs, b_coeffs) output = _lfilter(waveform, a_coeffs, b_coeffs)
if clamp: if clamp:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment