Unverified Commit 05bff83f authored by parmeet's avatar parmeet Committed by GitHub
Browse files

Add C++ lfilter core loop for CPU (#1244)

parent c3cb2015
......@@ -10,6 +10,7 @@ set(
sox/effects.cpp
sox/effects_chain.cpp
sox/types.cpp
lfilter.cpp
)
if(BUILD_TRANSDUCER)
......
#include <torch/script.h>
namespace {
template <typename scalar_t>
void host_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_channel = input_signal_windows.size(0);
int64_t n_samples_input = input_signal_windows.size(1);
int64_t n_samples_output = padded_output_waveform.size(1);
int64_t n_order = a_coeff_flipped.size(0);
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
int64_t offset_input = i_channel * n_samples_input;
int64_t offset_output = i_channel * n_samples_output;
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff];
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
}
}
void cpu_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
TORCH_CHECK(
input_signal_windows.device().is_cpu() &&
a_coeff_flipped.device().is_cpu() &&
padded_output_waveform.device().is_cpu());
TORCH_CHECK(
input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
padded_output_waveform.is_contiguous());
TORCH_CHECK(
(input_signal_windows.dtype() == torch::kFloat32 ||
input_signal_windows.dtype() == torch::kFloat64) &&
(a_coeff_flipped.dtype() == torch::kFloat32 ||
a_coeff_flipped.dtype() == torch::kFloat64) &&
(padded_output_waveform.dtype() == torch::kFloat32 ||
padded_output_waveform.dtype() == torch::kFloat64));
TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
TORCH_CHECK(
input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
padded_output_waveform.size(1));
AT_DISPATCH_FLOATING_TYPES(
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
host_lfilter_core_loop<scalar_t>(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
});
}
} // namespace
// Note: We want to avoid using "catch-all" kernel.
// The following registration should be replaced with CPU specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}
......@@ -808,6 +808,23 @@ def highpass_biquad(
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor):
n_order = a_coeffs_flipped.size(0)
for i_sample, o0 in enumerate(input_signal_windows.t()):
windowed_output_signal = padded_output_waveform[
:, i_sample:i_sample + n_order
]
o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
padded_output_waveform[:, i_sample + n_order - 1] = o0
try:
_lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
except RuntimeError as err:
assert str(err) == 'No such operator torchaudio::_lfilter_core_loop'
_lfilter_core_cpu_loop = _lfilter_core_generic_loop
def lfilter(
waveform: Tensor,
a_coeffs: Tensor,
......@@ -877,12 +894,13 @@ def lfilter(
input_signal_windows.div_(a_coeffs[0])
a_coeffs_flipped.div_(a_coeffs[0])
for i_sample, o0 in enumerate(input_signal_windows.t()):
windowed_output_signal = padded_output_waveform[
:, i_sample:i_sample + n_order
]
o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
padded_output_waveform[:, i_sample + n_order - 1] = o0
if input_signal_windows.device == torch.device('cpu') and\
a_coeffs_flipped.device == torch.device('cpu') and\
padded_output_waveform.device == torch.device('cpu'):
_lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
else:
_lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
output = padded_output_waveform[:, n_order - 1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment