overdrive.cpp 2.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <torch/script.h>
#include <torch/torch.h>

namespace {

template <typename scalar_t>
void overdrive_cpu_kernel(
    at::TensorAccessor<scalar_t, 2> waveform_accessor,
    at::TensorAccessor<scalar_t, 2> temp_accessor,
    at::TensorAccessor<scalar_t, 1> last_in_accessor,
    at::TensorAccessor<scalar_t, 1> last_out_accessor,
    at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
  int64_t n_frames = waveform_accessor.size(1);
  int64_t n_channels = waveform_accessor.size(0);

  at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
    for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
      for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
        last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
            last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
        last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
        output_waveform_accessor[i_channel][i_frame] =
            waveform_accessor[i_channel][i_frame] * 0.5 +
            last_out_accessor[i_channel] * 0.75;
      }
    }
  });
}

void overdrive_core_loop_cpu(
    at::Tensor& waveform,
    at::Tensor& temp,
    at::Tensor& last_in,
    at::Tensor& last_out,
    at::Tensor& output_waveform) {
  AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
                               overdrive_cpu_kernel<scalar_t>(
                                   waveform.accessor<scalar_t, 2>(),
                                   temp.accessor<scalar_t, 2>(),
                                   last_in.accessor<scalar_t, 1>(),
                                   last_out.accessor<scalar_t, 1>(),
                                   output_waveform.accessor<scalar_t, 2>());
                             }));
}

} // namespace

// TODO: this is a "catch-all" kernel registration, which we want to avoid.
// Replace it with a CPU-specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, library) {
  // Exposes overdrive_core_loop_cpu as "torchaudio::_overdrive_core_loop".
  library.def(
      "torchaudio::_overdrive_core_loop",
      &overdrive_core_loop_cpu);
}