// integrated_conv_cpu.cpp — CPU implementation of integrated_conv
// (forward and backward); see integrated_conv.py for the Python-level docs.
#include <torch/extension.h>



// Forward of integrated_conv.  The """ ... """ comment of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
//
// input:   shape (N, 2*C, H, W); channels [0, C) are the "source" image,
//          channels [C, 2*C) are the "destination" image.
// pos_add: shape (C, kH, kW); kH and kW must be odd.
// pos_mul: shape (C, kH, kW).
//
// Returns a tensor of shape (N, C, H, W) where, with `src` read from the
// source image at the kernel-offset position (treated as 0 outside the
// image) and dest = input[n][c + C][h][w]:
//   output[n][c][h][w] =
//     sum_{kh,kw} relu(src + dest + pos_add[c][kh][kw]) * pos_mul[c][kh][kw]
torch::Tensor integrated_conv_cpu(torch::Tensor input,
                                  torch::Tensor pos_add,
                                  torch::Tensor pos_mul) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  const int N = input.size(0),
      C = input.size(1) / 2,
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);
  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1, "Kernel sizes must be odd");
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");
  // Named `dtype`, not `scalar_t`: AT_DISPATCH_FLOATING_TYPES defines a
  // *type* called scalar_t inside its lambda, which would shadow a variable
  // of the same name.
  auto dtype = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == dtype &&
              pos_mul.scalar_type() == dtype,
              "Input dtypes mismatch");

  torch::Tensor output = torch::empty({N, C, H, W},
                                      torch::TensorOptions().dtype(dtype).device(input.device()));

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
        auto input_a = input.accessor<scalar_t, 4>(),
            output_a = output.accessor<scalar_t, 4>();
        auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
            pos_mul_a = pos_mul.accessor<scalar_t, 3>();

        for (int n = 0; n < N; n++) {
          for (int c = 0; c < C; c++) {
            // Per-(n, c) sub-accessors, hoisted out of the pixel loops.
            auto src_input_a = input_a[n][c],
                this_pos_add_a = pos_add_a[c],
                this_pos_mul_a = pos_mul_a[c],
                this_output_a = output_a[n][c];
            for (int h = 0; h < H; h++) {
              for (int w = 0; w < W; w++) {
                scalar_t dest = input_a[n][c + C][h][w],
                    sum = 0.0;
                for (int kh = 0; kh < kH; kh++) {
                  int src_h = h + kh - kH / 2;  // kernel is centered on (h, w)
                  for (int kw = 0; kw < kW; kw++) {
                    int src_w = w + kw - kW / 2;
                    // Single unsigned comparison implements 0 <= src_h < H
                    // (negative values wrap to large unsigned values);
                    // out-of-image positions contribute src = 0.
                    scalar_t src = 0.0;
                    if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                        static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                      src = src_input_a[src_h][src_w];
                    scalar_t relu = src + dest + this_pos_add_a[kh][kw];
                    if (relu > 0.0)
                      sum += relu * this_pos_mul_a[kh][kw];
                  }
                }
                this_output_a[h][w] = sum;
              }
            }
          }
        }
      }));
  return output;
}

// Backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
//
// input, pos_add, pos_mul: as in integrated_conv_cpu().
// grad_output: shape (N, C, H, W), gradient w.r.t. the forward output.
//
// Returns gradients with the same shapes as the corresponding inputs:
// grad_input (N, 2*C, H, W), grad_pos_add (C, kH, kW), grad_pos_mul (C, kH, kW).
std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
                                                        torch::Tensor pos_add,
                                                        torch::Tensor pos_mul,
                                                        torch::Tensor grad_output) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  const int N = input.size(0),
      C = input.size(1) / 2,
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);
  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1, "Kernel sizes must be odd");
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");
  // Named `dtype`, not `scalar_t`: AT_DISPATCH_FLOATING_TYPES defines a
  // *type* called scalar_t inside its lambda, which would shadow a variable
  // of the same name.
  auto dtype = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == dtype &&
              pos_mul.scalar_type() == dtype,
              "Input dtypes mismatch");
  TORCH_CHECK(grad_output.dim() == 4 && grad_output.size(0) == N
              && grad_output.size(1) == C && grad_output.size(2) == H
              && grad_output.size(3) == W);
  TORCH_CHECK(grad_output.device() == input.device() &&
              grad_output.scalar_type() == dtype,
              "grad_output device/dtype mismatch");

  torch::Tensor grad_input = torch::zeros({N, 2*C, H, W},
                                          torch::TensorOptions().dtype(dtype).device(input.device())),
      grad_pos_add = torch::zeros({C, kH, kW},
                                  torch::TensorOptions().dtype(dtype).device(input.device())),
      grad_pos_mul = torch::zeros({C, kH, kW},
                                  torch::TensorOptions().dtype(dtype).device(input.device()));


  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
        auto input_a = input.accessor<scalar_t, 4>(),
            grad_output_a = grad_output.accessor<scalar_t, 4>(),
            grad_input_a = grad_input.accessor<scalar_t, 4>();
        auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
            pos_mul_a = pos_mul.accessor<scalar_t, 3>(),
            grad_pos_add_a = grad_pos_add.accessor<scalar_t, 3>(),
            grad_pos_mul_a = grad_pos_mul.accessor<scalar_t, 3>();

        for (int n = 0; n < N; n++) {
          for (int c = 0; c < C; c++) {

            for (int h = 0; h < H; h++) {
              for (int w = 0; w < W; w++) {
                scalar_t dest = input_a[n][c + C][h][w],
                    dest_grad = 0.0,  // to be multiplied by this_grad_output later..
                    this_grad_output = grad_output_a[n][c][h][w];
                for (int kh = 0; kh < kH; kh++) {
                  int src_h = h + kh - kH / 2;
                  for (int kw = 0; kw < kW; kw++) {
                    int src_w = w + kw - kW / 2;
                    // Single unsigned comparison implements 0 <= src_h < H
                    // (negative values wrap to large unsigned values).
                    const bool in_bounds =
                        static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                        static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W);
                    scalar_t src = 0.0;
                    if (in_bounds)
                      src = input_a[n][c][src_h][src_w];
                    scalar_t relu = src + dest + pos_add_a[c][kh][kw];
                    // NOTE(review): forward uses `relu > 0.0` but backward
                    // uses `relu >= 0.0`; they differ only on the
                    // measure-zero boundary relu == 0 (a subgradient
                    // choice), so existing behavior is preserved here.
                    if (relu >= 0.0) {
                      scalar_t pos_mul_val = pos_mul_a[c][kh][kw];
                      dest_grad += pos_mul_val;  // will later multiply by this_output_grad
                      grad_pos_add_a[c][kh][kw] += this_grad_output * pos_mul_val;
                      grad_pos_mul_a[c][kh][kw] += this_grad_output * relu;
                      if (in_bounds)
                        grad_input_a[n][c][src_h][src_w] += this_grad_output * pos_mul_val;
                    }
                  }
                }
                grad_input_a[n][c + C][h][w] += dest_grad * this_grad_output;
              }
            }
          }
        }
      }));

  return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul});
}





// Python bindings: exposes the forward and backward CPU kernels.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("integrated_conv_cpu", &integrated_conv_cpu,
        "Integrated convolution forward function (CPU)");
  m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu,
        "Integrated convolution backward function (CPU)");
}