Unverified Commit fd4ee240 authored by chin yun yu's avatar chin yun yu Committed by GitHub
Browse files

Use c++ backend for lfilter when c++ extension is available (#1319)

parent 720323c6
#include <torch/script.h>
#include <torch/torch.h>
namespace {
...@@ -62,6 +63,69 @@ void cpu_lfilter_core_loop(
  });
}
// Device-agnostic feedback (IIR) loop for lfilter.
//
// For each time step t, subtracts the dot product of the previous `order`
// output samples with the flipped denominator coefficients from the
// b-filtered input, and writes the result back into the output buffer.
// Runs sequentially over time because each step depends on earlier outputs.
//
// Args:
//   input_signal_windows: (batch, time) tensor of b-filtered input samples.
//   a_coeff_flipped: 1-D flipped denominator coefficients of length `order`.
//   padded_output_waveform: (batch, time + order - 1) output buffer,
//       mutated in place.
void lfilter_core_generic_loop(
    const torch::Tensor& input_signal_windows,
    const torch::Tensor& a_coeff_flipped,
    torch::Tensor& padded_output_waveform) {
  using torch::indexing::Slice;
  const int64_t num_steps = input_signal_windows.size(1);
  const int64_t order = a_coeff_flipped.size(0);
  for (int64_t t = 0; t < num_steps; ++t) {
    // The window of the most recent `order` output samples for every batch row.
    auto out_window =
        padded_output_waveform.index({Slice(), Slice(t, t + order)});
    // addmv(mat, vec, beta=1, alpha=-1): input[t] - out_window @ a_flipped.
    auto next_sample = input_signal_windows.index({Slice(), t})
                           .addmv(out_window, a_coeff_flipped, 1, -1);
    padded_output_waveform.index_put_({Slice(), t + order - 1}, next_sample);
  }
}
// Evaluate an IIR difference equation over `waveform`.
//
// Applies the numerator (b) coefficients via a single conv1d pass, then runs
// the sequential denominator (a) feedback loop, using the optimized CPU loop
// when the tensors live on CPU and the generic torch-op loop otherwise.
//
// Args:
//   waveform: 2-D tensor of shape (batch, time).
//   a_coeffs: 1-D denominator coefficients, lowest delay first (a[0] first).
//   b_coeffs: 1-D numerator coefficients, same length as a_coeffs.
// Returns: filtered waveform with the same shape as `waveform`.
torch::Tensor lfilter_core(
    const torch::Tensor& waveform,
    const torch::Tensor& a_coeffs,
    const torch::Tensor& b_coeffs) {
  // All inputs must share one device, and coefficient lengths must match.
  TORCH_CHECK(waveform.device() == a_coeffs.device());
  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
  auto device = waveform.device();
  int64_t n_order = a_coeffs.size(0);
  TORCH_INTERNAL_ASSERT(n_order > 0);
  namespace F = torch::nn::functional;
  // Left-pad by n_order - 1 zeros so every step has a full coefficient window.
  auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
  auto padded_output_waveform = torch::zeros_like(padded_waveform);
  // Flip coefficients so a sliding-window inner product aligns delays.
  auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
  auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
  // FIR stage: correlate the padded input against the b coefficients.
  auto input_signal_windows =
      F::conv1d(
          padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
          .squeeze(1);
  // Normalize by a[0] (in place) so the feedback loop can assume a leading
  // coefficient of 1. Must happen before either loop below runs.
  input_signal_windows.div_(a_coeffs[0]);
  a_coeff_flipped.div_(a_coeffs[0]);
  // IIR stage: sequential feedback recursion (optimized path on CPU).
  if (device.is_cpu()) {
    cpu_lfilter_core_loop(
        input_signal_windows, a_coeff_flipped, padded_output_waveform);
  } else {
    lfilter_core_generic_loop(
        input_signal_windows, a_coeff_flipped, padded_output_waveform);
  }
  // Drop the left padding to restore the original time length.
  auto output = padded_output_waveform.index(
      {torch::indexing::Slice(),
       torch::indexing::Slice(n_order - 1, torch::indexing::None)});
  return output;
}
} // namespace
// Note: We want to avoid using "catch-all" kernel.
...@@ -69,3 +133,12 @@ void cpu_lfilter_core_loop(
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}
// Declare the `_lfilter` operator schema in the `torchaudio` namespace so it
// is reachable from Python as torch.ops.torchaudio._lfilter.
TORCH_LIBRARY(torchaudio, m) {
  m.def(
      "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
}
// Register lfilter_core under the `Math` dispatch key — presumably because it
// is composed entirely of other torch ops, one registration serves every
// backend (NOTE(review): confirm against the dispatcher docs for this
// PyTorch version).
TORCH_LIBRARY_IMPL(torchaudio, Math, m) {
  m.impl("torchaudio::_lfilter", lfilter_core);
}
...@@ -825,30 +825,11 @@ except RuntimeError as err:
_lfilter_core_cpu_loop = _lfilter_core_generic_loop
def lfilter( def _lfilter_core(
waveform: Tensor, waveform: Tensor,
a_coeffs: Tensor, a_coeffs: Tensor,
b_coeffs: Tensor, b_coeffs: Tensor,
clamp: bool = True,
) -> Tensor: ) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.
Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns:
Tensor: Waveform with dimension of ``(..., time)``.
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
    assert a_coeffs.size(0) == b_coeffs.size(0)
    assert len(waveform.size()) == 2
...@@ -886,6 +867,41 @@ def lfilter(
        _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
    output = padded_output_waveform[:, n_order - 1:]
return output
# Bind `_lfilter` to the C++ operator when the torchaudio extension has
# registered it; otherwise fall back to the pure-Python implementation.
try:
    _lfilter = torch.ops.torchaudio._lfilter
except RuntimeError as err:
    # Only the specific "operator not registered" failure is expected here;
    # any other RuntimeError message should surface as an AssertionError.
    assert str(err) == 'No such operator torchaudio::_lfilter'
    _lfilter = _lfilter_core
def lfilter(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.
Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns:
Tensor: Waveform with dimension of ``(..., time)``.
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
output = _lfilter(waveform, a_coeffs, b_coeffs)
    if clamp:
        output = torch.clamp(output, min=-1.0, max=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment