// integrated_conv_cpu.cpp: CPU implementation of the integrated_conv
// PyTorch extension.
#include <torch/extension.h>



// Forward of integrated_conv.  The """...""" comment of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
//
// Args:
//   input:   4-dimensional CPU tensor of shape (N, 2*C, H, W).  Channels
//            [0, C) are read over a kH x kW window centered on each output
//            position (positions outside the H x W image read as 0);
//            channels [C, 2*C) are read only at the output position itself.
//   pos_add: tensor of shape (C, kH, kW), added inside the ReLU.
//   pos_mul: tensor of shape (C, kH, kW), scales each ReLU term.
//   kH and kW must be odd; all three tensors must share device and dtype.
// Returns:
//   A new tensor of shape (N, C, H, W) with
//     output[n][c][h][w] = sum_{kh, kw} pos_mul[c][kh][kw] *
//         relu(src + input[n][c + C][h][w] + pos_add[c][kh][kw]),
//   where src = input[n][c][h + kh - kH/2][w + kw - kW/2], or 0 if that
//   position falls outside the image.  Note that out-of-range taps still
//   contribute relu(dest + pos_add) terms.
torch::Tensor integrated_conv_cpu(torch::Tensor input,
                                  torch::Tensor pos_add,
                                  torch::Tensor pos_mul) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  const int N = input.size(0),
      C = input.size(1) / 2,
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);
  // Odd kernel sizes keep the window centered on (h, w) via the kH/2, kW/2 offsets.
  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1, "Kernel sizes must be odd");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");

  // Named `scalar_type` (not `scalar_t`) so it does not shadow the type alias
  // that AT_DISPATCH_FLOATING_TYPES introduces inside the lambda below.
  auto scalar_type = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == scalar_type &&
              pos_mul.scalar_type() == scalar_type,
              "Input dtypes mismatch");

  torch::Tensor output = torch::empty({N, C, H, W},
                                      torch::TensorOptions().dtype(scalar_type).device(input.device()));

  AT_DISPATCH_FLOATING_TYPES(scalar_type, "integrated_conv_cpu_loop", ([&] {
        auto input_a = input.accessor<scalar_t, 4>(),
            output_a = output.accessor<scalar_t, 4>();
        auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
            pos_mul_a = pos_mul.accessor<scalar_t, 3>();

        for (int n = 0; n < N; n++) {
          for (int c = 0; c < C; c++) {
            auto src_input_a = input_a[n][c],
                dest_input_a = input_a[n][c + C],
                this_pos_add_a = pos_add_a[c],
                this_pos_mul_a = pos_mul_a[c],
                this_output_a = output_a[n][c];
            for (int h = 0; h < H; h++) {
              for (int w = 0; w < W; w++) {
                scalar_t dest = dest_input_a[h][w],
                    sum = 0.0;
                for (int kh = 0; kh < kH; kh++) {
                  int src_h = h + kh - kH / 2;
                  for (int kw = 0; kw < kW; kw++) {
                    int src_w = w + kw - kW / 2;
                    scalar_t src = 0.0;
                    // A negative index wraps to a large unsigned value, so a
                    // single unsigned comparison checks both bounds per axis.
                    if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                        static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                      src = src_input_a[src_h][src_w];
                    scalar_t relu = src + dest + this_pos_add_a[kh][kw];
                    if (relu > 0.0)
                      sum += relu * this_pos_mul_a[kh][kw];
                  }
                }
                this_output_a[h][w] = sum;
              }
            }
          }
        }
      }));
  return output;
}

// Backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
//
// Args:
//   input, pos_add, pos_mul: the same tensors passed to integrated_conv_cpu,
//       of shapes (N, 2*C, H, W), (C, kH, kW) and (C, kH, kW).
//   grad_output: gradient w.r.t. the forward output, shape (N, C, H, W).
// Returns:
//   {grad_input, grad_pos_add, grad_pos_mul}, shaped like input, pos_add and
//   pos_mul respectively.
//
// For every forward term z = src + dest + pos_add[c][kh][kw] with z > 0 (the
// forward's ReLU uses a strict `> 0`, so z == 0 contributes no gradient),
// with g = grad_output[n][c][h][w] and g_relu = g * pos_mul[c][kh][kw]:
//   grad_pos_mul[c][kh][kw]        += g * z
//   grad_pos_add[c][kh][kw]        += g_relu
//   grad_input[n][c + C][h][w]     += g_relu
//   grad_input[n][c][src_h][src_w] += g_relu   (only if the tap is in range)
std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
                                                        torch::Tensor pos_add,
                                                        torch::Tensor pos_mul,
                                                        torch::Tensor grad_output) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4-dimensional");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  const int N = input.size(0),
      C = input.size(1) / 2,
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);
  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1, "Kernel sizes must be odd");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(grad_output.size(0) == N && grad_output.size(1) == C &&
              grad_output.size(2) == H && grad_output.size(3) == W,
              "grad_output size mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == input.device() &&
              grad_output.device() == input.device(),
              "Input devices mismatch");
  auto scalar_type = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == scalar_type &&
              pos_mul.scalar_type() == scalar_type &&
              grad_output.scalar_type() == scalar_type,
              "Input dtypes mismatch");

  // Gradients are accumulated with +=, so they must start at zero.
  torch::Tensor grad_input = torch::zeros_like(input),
      grad_pos_add = torch::zeros_like(pos_add),
      grad_pos_mul = torch::zeros_like(pos_mul);

  AT_DISPATCH_FLOATING_TYPES(scalar_type, "integrated_conv_backward_cpu_loop", ([&] {
        auto input_a = input.accessor<scalar_t, 4>(),
            grad_output_a = grad_output.accessor<scalar_t, 4>(),
            grad_input_a = grad_input.accessor<scalar_t, 4>();
        auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
            pos_mul_a = pos_mul.accessor<scalar_t, 3>(),
            grad_pos_add_a = grad_pos_add.accessor<scalar_t, 3>(),
            grad_pos_mul_a = grad_pos_mul.accessor<scalar_t, 3>();

        for (int n = 0; n < N; n++) {
          for (int c = 0; c < C; c++) {
            auto src_input_a = input_a[n][c],
                dest_input_a = input_a[n][c + C],
                src_grad_input_a = grad_input_a[n][c],
                dest_grad_input_a = grad_input_a[n][c + C],
                this_grad_output_a = grad_output_a[n][c],
                this_pos_add_a = pos_add_a[c],
                this_pos_mul_a = pos_mul_a[c],
                this_grad_pos_add_a = grad_pos_add_a[c],
                this_grad_pos_mul_a = grad_pos_mul_a[c];
            for (int h = 0; h < H; h++) {
              for (int w = 0; w < W; w++) {
                const scalar_t dest = dest_input_a[h][w],
                    grad_out = this_grad_output_a[h][w];
                scalar_t grad_dest = 0.0;
                for (int kh = 0; kh < kH; kh++) {
                  const int src_h = h + kh - kH / 2;
                  for (int kw = 0; kw < kW; kw++) {
                    const int src_w = w + kw - kW / 2;
                    // Same bounds test as the forward: negative indices wrap
                    // to large unsigned values and fail the comparison.
                    const bool in_range =
                        static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                        static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W);
                    const scalar_t src = in_range ? src_input_a[src_h][src_w]
                                                  : static_cast<scalar_t>(0.0);
                    const scalar_t relu = src + dest + this_pos_add_a[kh][kw];
                    if (relu > 0.0) {
                      const scalar_t grad_relu = grad_out * this_pos_mul_a[kh][kw];
                      this_grad_pos_mul_a[kh][kw] += grad_out * relu;
                      this_grad_pos_add_a[kh][kw] += grad_relu;
                      grad_dest += grad_relu;
                      if (in_range)
                        src_grad_input_a[src_h][src_w] += grad_relu;
                    }
                  }
                }
                // Only output (n, c, h, w) reads dest input (n, c + C, h, w),
                // so a plain assignment is sufficient here.
                dest_grad_input_a[h][w] = grad_dest;
              }
            }
          }
        }
      }));
  return {grad_input, grad_pos_add, grad_pos_mul};
}





// Python bindings for the CPU forward and backward functions above.
// TORCH_EXTENSION_NAME is presumably supplied by the torch extension build
// (not defined in this file) -- confirm against setup.py / load() call.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("integrated_conv_cpu", &integrated_conv_cpu, "Integrated convolution forward function (CPU)");
  m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu, "Integrated convolution backward function (CPU)");
}