#include #include namespace { template void overdrive_cpu_kernel( at::TensorAccessor waveform_accessor, at::TensorAccessor temp_accessor, at::TensorAccessor last_in_accessor, at::TensorAccessor last_out_accessor, at::TensorAccessor output_waveform_accessor) { int64_t n_frames = waveform_accessor.size(1); int64_t n_channels = waveform_accessor.size(0); at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) { for (int64_t i_channel = begin; i_channel < end; ++i_channel) { for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; output_waveform_accessor[i_channel][i_frame] = waveform_accessor[i_channel][i_frame] * 0.5 + last_out_accessor[i_channel] * 0.75; } } }); } void overdrive_core_loop_cpu( at::Tensor& waveform, at::Tensor& temp, at::Tensor& last_in, at::Tensor& last_out, at::Tensor& output_waveform) { AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] { overdrive_cpu_kernel( waveform.accessor(), temp.accessor(), last_in.accessor(), last_out.accessor(), output_waveform.accessor()); })); } } // namespace // Note: We want to avoid using "catch-all" kernel. // The following registration should be replaced with CPU specific registration. TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu); }