Unverified Commit 23e9ed34 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Implement faster CPU overdrive in C++ (#1299)

parent e83d557a
...@@ -11,6 +11,7 @@ set( ...@@ -11,6 +11,7 @@ set(
sox/effects_chain.cpp sox/effects_chain.cpp
sox/types.cpp sox/types.cpp
lfilter.cpp lfilter.cpp
overdrive.cpp
) )
if(BUILD_TRANSDUCER) if(BUILD_TRANSDUCER)
......
#include <torch/script.h>
#include <torch/torch.h>
namespace {
// Runs the overdrive recurrence over every channel of a 2-D waveform.
//
// For each channel c and each frame t (in increasing order of t):
//   last_out[c] = temp[c][t] - last_in[c] + 0.995 * last_out[c]
//   last_in[c]  = temp[c][t]
//   output[c][t] = waveform[c][t] * 0.5 + last_out[c] * 0.75
//
// Dimension 0 of `waveform_accessor` is used as the channel count and
// dimension 1 as the frame count; the other accessors are indexed with the
// same ranges and are not bounds-checked here.
// `last_in_accessor` and `last_out_accessor` are updated in place and hold the
// state after the final frame when the function returns.
// Channels are distributed over at::parallel_for with a grain size of 1; the
// frame loop of a single channel is sequential because each step depends on
// the previous one.
template <typename scalar_t>
void overdrive_cpu_kernel(
    at::TensorAccessor<scalar_t, 2> waveform_accessor,
    at::TensorAccessor<scalar_t, 2> temp_accessor,
    at::TensorAccessor<scalar_t, 1> last_in_accessor,
    at::TensorAccessor<scalar_t, 1> last_out_accessor,
    at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
  const int64_t num_channels = waveform_accessor.size(0);
  const int64_t num_frames = waveform_accessor.size(1);

  const auto process_channels = [&](int64_t first, int64_t last) {
    for (int64_t ch = first; ch < last; ++ch) {
      // Row views for this channel.
      auto dry = waveform_accessor[ch];
      auto shaped = temp_accessor[ch];
      auto wet = output_waveform_accessor[ch];
      // References, so every update is written straight back to the state
      // tensors exactly as an indexed store would be.
      scalar_t& prev_in = last_in_accessor[ch];
      scalar_t& prev_out = last_out_accessor[ch];

      for (int64_t t = 0; t < num_frames; ++t) {
        // The literals are double, so for scalar_t == float the right-hand
        // sides are evaluated in double and narrowed on assignment.
        prev_out = shaped[t] - prev_in + 0.995 * prev_out;
        prev_in = shaped[t];
        wet[t] = dry[t] * 0.5 + prev_out * 0.75;
      }
    }
  };

  at::parallel_for(0, num_channels, 1, process_channels);
}
// Entry point for the overdrive core loop on CPU tensors.
//
// Validates the arguments, then dispatches on the floating-point dtype of
// `waveform` and runs overdrive_cpu_kernel over raw accessors.
//
// Arguments:
//   waveform        2-D input tensor; dim 0 is iterated as channels, dim 1 as
//                   frames.
//   temp            2-D tensor with the same shape as `waveform`.
//   last_in         1-D state tensor with one element per channel; updated in
//                   place.
//   last_out        1-D state tensor with one element per channel; updated in
//                   place.
//   output_waveform 2-D tensor with the same shape as `waveform`; written in
//                   place.
//
// Raises a c10::Error (via TORCH_CHECK) when a tensor is not on the CPU, when
// dtypes differ, or when ranks/shapes are inconsistent. The kernel indexes
// host memory through unchecked accessors, so without these guards such
// inputs would cause out-of-bounds or invalid memory access.
void overdrive_core_loop_cpu(
    at::Tensor& waveform,
    at::Tensor& temp,
    at::Tensor& last_in,
    at::Tensor& last_out,
    at::Tensor& output_waveform) {
  // The op is registered without a dispatch key, so nothing upstream
  // guarantees that the tensors live in host memory.
  TORCH_CHECK(
      waveform.device().is_cpu() && temp.device().is_cpu() &&
          last_in.device().is_cpu() && last_out.device().is_cpu() &&
          output_waveform.device().is_cpu(),
      "overdrive_core_loop_cpu: all tensors must be on the CPU");

  const auto dtype = waveform.scalar_type();
  TORCH_CHECK(
      temp.scalar_type() == dtype && last_in.scalar_type() == dtype &&
          last_out.scalar_type() == dtype &&
          output_waveform.scalar_type() == dtype,
      "overdrive_core_loop_cpu: all tensors must share the dtype of waveform");

  TORCH_CHECK(
      waveform.dim() == 2 && temp.dim() == 2 && output_waveform.dim() == 2,
      "overdrive_core_loop_cpu: waveform, temp and output_waveform must be 2-D");
  TORCH_CHECK(
      last_in.dim() == 1 && last_out.dim() == 1,
      "overdrive_core_loop_cpu: last_in and last_out must be 1-D");

  // The kernel takes its loop bounds from `waveform` only; every other
  // tensor must be at least that large, so require matching shapes.
  TORCH_CHECK(
      temp.sizes() == waveform.sizes() &&
          output_waveform.sizes() == waveform.sizes(),
      "overdrive_core_loop_cpu: temp and output_waveform must have the same "
      "shape as waveform");
  TORCH_CHECK(
      last_in.size(0) == waveform.size(0) &&
          last_out.size(0) == waveform.size(0),
      "overdrive_core_loop_cpu: last_in and last_out must have one element "
      "per channel");

  AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
                               overdrive_cpu_kernel<scalar_t>(
                                   waveform.accessor<scalar_t, 2>(),
                                   temp.accessor<scalar_t, 2>(),
                                   last_in.accessor<scalar_t, 1>(),
                                   last_out.accessor<scalar_t, 1>(),
                                   output_waveform.accessor<scalar_t, 2>());
                             }));
}
} // namespace
// Registers the operator `torchaudio::_overdrive_core_loop` and binds it to
// overdrive_core_loop_cpu.
//
// Note: We want to avoid using a "catch-all" kernel.
// The registration below is a catch-all one; it should be replaced with a
// CPU-specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
// Operator schema is inferred from the C++ function signature.
m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
}
...@@ -939,6 +939,26 @@ def lowpass_biquad( ...@@ -939,6 +939,26 @@ def lowpass_biquad(
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def _overdrive_core_loop_generic(
waveform: Tensor,
temp: Tensor,
last_in: Tensor,
last_out: Tensor,
output_waveform: Tensor
):
for i in range(waveform.shape[-1]):
last_out = temp[:, i] - last_in + 0.995 * last_out
last_in = temp[:, i]
output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75
# Prefer the compiled operator when the C++ extension registered it; otherwise
# fall back to the pure-PyTorch loop defined above.
try:
    _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop
except RuntimeError as err:
    # Only the "operator is not registered" error selects the fallback. Any
    # other RuntimeError is re-raised unchanged; a plain `assert` here would be
    # stripped under `python -O` and would hide the original error otherwise.
    if str(err) != 'No such operator torchaudio::_overdrive_core_loop':
        raise
    _overdrive_core_loop_cpu = _overdrive_core_loop_generic
def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor: def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
r"""Apply a overdrive effect to the audio. Similar to SoX implementation. r"""Apply a overdrive effect to the audio. Similar to SoX implementation.
This effect applies a non linear distortion to the audio signal. This effect applies a non linear distortion to the audio signal.
...@@ -981,11 +1001,11 @@ def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor: ...@@ -981,11 +1001,11 @@ def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device) output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)
# TODO: Implement a torch CPP extension # Uses CPU optimized loop function if available for CPU device
for i in range(waveform.shape[-1]): if device == torch.device('cpu'):
last_out = temp[:, i] - last_in + 0.995 * last_out _overdrive_core_loop_cpu(waveform, temp, last_in, last_out, output_waveform)
last_in = temp[:, i] else:
output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75 _overdrive_core_loop_generic(waveform, temp, last_in, last_out, output_waveform)
return output_waveform.clamp(min=-1, max=1).view(actual_shape) return output_waveform.clamp(min=-1, max=1).view(actual_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment