#include <torch/extension.h>



// Forward of learned_nonlin.  See the docstring of `learned_nonlin` in
// learned_nonlin.py for documentation of this function's behavior.
torch::Tensor learned_nonlin_cpu(torch::Tensor input,
                                 torch::Tensor params) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  TORCH_CHECK(params.size(1) >= 3 &&
              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
              "params.size(1) has invalid value, must be a power of 2 plus 1.");
  TORCH_CHECK(params.size(0) == input.size(1),
              "params vs input channels mismatch");

  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
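
  // Layout of `params` (as used below): params[c][0] is a log-scale, with
  // scale = exp(params[c][0]); params[c][1] .. params[c][N] are the slopes
  // of the N = params.size(1) - 1 linear pieces of the nonlinearity.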

  const int B = input.size(0),
      C = input.size(1),
      T = input.size(2),
      N = params.size(1) - 1,
      K = N / 2;

  auto scalar_type = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_type).device(input.device());

  torch::Tensor y_vals = torch::empty({C, N}, opts),
    output = torch::empty({B, C, T}, opts);

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_cpu_loop", ([&] {
        auto params_a = params.accessor<scalar_t, 2>(),
            y_vals_a = y_vals.accessor<scalar_t, 2>();
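        // For each channel, fill in y_vals_a[c][n]: the value of the
        // piecewise-linear function at the breakpoint x == n (i.e. at
        // input == (n - K) * scale).  The function is anchored at zero at
        // x == K (input == 0) and accumulated outward using the slopes in
        // params_a[c][1] .. params_a[c][N].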
        for (int c = 0; c < C; c++) {
          scalar_t sum_negative = 0.0,
              sum_positive = 0.0,
              scale = std::exp(params_a[c][0]);
          for (int i = 0; i < K; i++) {
            y_vals_a[c][K + i] = sum_positive;
            y_vals_a[c][K - i] = sum_negative;
            sum_positive += params_a[c][1 + K + i] * scale;
            sum_negative -= params_a[c][K - i] * scale;
          }
          // Let the reference point for y_vals_a[c][0] be -K, although the
          // interval actually starts at -(K-1).  This reference point is
          // arbitrary but using it makes our lives easier when processing the
          // data.
          y_vals_a[c][0] = sum_negative;
        }

        auto input_a = input.accessor<scalar_t, 3>(),
            output_a = output.accessor<scalar_t, 3>();

        for (int b = 0; b < B; b++) {
          for (int c = 0; c < C; c++) {
            scalar_t scale = std::exp(params_a[c][0]),
                inv_scale = 1.0 / scale;
            for (int t = 0; t < T; t++) {
              // `x` is the scaled input x plus an offset so that -K maps to 0.
              // Note: the discontinuities in our function's slope are at
              // -(K-1) ... +(K-1), so in a sense -K and +K are not special,
              // but we include those extra values as an easy way to handle the
              // semi-infinite regions that are < -(K-1) and > (K-1).
              scalar_t x = input_a[b][c][t] * inv_scale + K,
                  x_trunc = x;
              if (x_trunc < 0) x_trunc = 0;
              else if (x_trunc >= N) x_trunc = N - 1;
              // C++ rounds toward zero.
              int n = (int) x_trunc;
              // OK, at this point, 0 <= n < 2*K.
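              // Piecewise-linear interpolation: within piece n the slope
              // w.r.t. the (unscaled) input is params_a[c][n + 1], and
              // y_vals_a[c][n] is the function value at the reference point
              // x == n.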
              scalar_t y = (x - n) * scale * params_a[c][n + 1] + y_vals_a[c][n];
              output_a[b][c][t] = y;
            }
          }
        }}));
  return output;
}
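
// Worked illustration (ours, not from learned_nonlin.py): with
// params.size(1) == 3 we have N == 2, K == 1, and params[c] == (l, s0, s1),
// with scale = exp(l).  The loop above gives y_vals[c] == (-s0 * scale, 0),
// and for every input u the function reduces to
//   f(u) = s0 * u  (u < 0),    f(u) = s1 * u  (u >= 0),
// i.e. a PReLU-like nonlinearity with learned slopes on both sides.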


// Backward of learned_nonlin.  Returns (input_grad, params_grad).

std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
                                                       torch::Tensor params,
                                                       torch::Tensor output_grad) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  TORCH_CHECK(params.size(1) >= 3 &&
              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
              "params.size(1) has invalid value, must be a power of 2 plus 1.");
  TORCH_CHECK(params.size(0) == input.size(1),
              "params vs input channels mismatch");
  TORCH_CHECK(input.sizes() == output_grad.sizes(),
              "Output-grad vs. input sizes mismatch.");

  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
  TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");

  const int B = input.size(0),
      C = input.size(1),
      T = input.size(2),
      N = params.size(1) - 1,
      K = N / 2;

  auto scalar_type = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_type).device(input.device());

  torch::Tensor y_vals = torch::empty({C, N}, opts),
      y_vals_grad = torch::zeros({C, N}, opts),
      params_grad = torch::zeros({C, N + 1}, opts),
      input_grad = torch::zeros({B, C, T}, opts);

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_cpu_loop", ([&] {
        auto params_a = params.accessor<scalar_t, 2>(),
            params_grad_a = params_grad.accessor<scalar_t, 2>(),
            y_vals_a = y_vals.accessor<scalar_t, 2>(),
            y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
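        // Recompute the breakpoint values y_vals exactly as in the forward
        // pass.  (Only y_vals_grad, accumulated further below, feeds into
        // params_grad; the y_vals values themselves are not read again.)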
        for (int c = 0; c < C; c++) {
          scalar_t sum_negative = 0.0,
              sum_positive = 0.0,
              scale = std::exp(params_a[c][0]);
          for (int i = 0; i < K; i++) {
            y_vals_a[c][K + i] = sum_positive;
            y_vals_a[c][K - i] = sum_negative;
            sum_positive += params_a[c][1 + K + i] * scale;
            sum_negative -= params_a[c][K - i] * scale;
          }
          // The reference point for the lowest, half-infinite interval (the
          // one starting at x = -(K-1)) is x = -K; this is arbitrary but
          // makes the computation more regular.
          y_vals_a[c][0] = sum_negative;
        }

        auto input_a = input.accessor<scalar_t, 3>(),
            output_grad_a = output_grad.accessor<scalar_t, 3>(),
            input_grad_a = input_grad.accessor<scalar_t, 3>();

        for (int b = 0; b < B; b++) {
          for (int c = 0; c < C; c++) {
            scalar_t scale = std::exp(params_a[c][0]),
                inv_scale = 1.0 / scale,
                inv_scale_grad = 0.0,
                scale_grad = 0.0;
            for (int t = 0; t < T; t++) {
              // `x` is the scaled input x plus an offset so that -K maps to 0.
              // Note: the discontinuities in our function's slope are at
              // -(K-1) ... +(K-1), so in a sense -K and +K are not special,
              // but we include those extra values as an easy way to handle the
              // semi-infinite regions that are < -(K-1) and > (K-1).
              // (named input_val to avoid shadowing the `input` tensor)
              scalar_t input_val = input_a[b][c][t],
                  x = input_val * inv_scale + K,
                  y_grad = output_grad_a[b][c][t],
                  x_trunc = x;
              if (x_trunc < 0) x_trunc = 0;
              else if (x_trunc >= N) x_trunc = N - 1;
              // C++ rounds toward zero.
              int n = (int) x_trunc;
              // OK, at this point, 0 <= n < 2*K.
              // backprop for:
              // scalar_t x_residual_scaled = (x - (scalar_t)n) * scale
              // scalar_t y = x_residual_scaled * params_a[c][n + 1] + y_vals_a[c][n];
              scalar_t x_residual_scaled = (x - n) * scale,
                  x_residual_scaled_grad = y_grad * params_a[c][n + 1],
                  x_grad = x_residual_scaled_grad * scale;
              scale_grad += x_residual_scaled_grad * (x - (scalar_t)n);
              params_grad_a[c][n + 1] += y_grad * x_residual_scaled;
              y_vals_grad_a[c][n] += y_grad;
              // backprop for:  x = input_val * inv_scale + K,
              inv_scale_grad += x_grad * input_val;
              input_grad_a[b][c][t] = x_grad * inv_scale;
            }
            // Do the backprop for:
            //    scale = exp(params_a[c][0]);
            //    inv_scale = exp(-params_a[c][0]);
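            // d(scale)/d(params_a[c][0]) = scale and
            // d(inv_scale)/d(params_a[c][0]) = -inv_scale, hence the signs: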
            params_grad_a[c][0] += (scale * scale_grad - inv_scale * inv_scale_grad);
          }
        }
        // Now do the backprop for the loop above where we set y_vals_a.
        for (int c = 0; c < C; c++) {
          scalar_t scale = std::exp(params_a[c][0]),
              scale_grad = 0.0,
              sum_negative_grad = y_vals_grad_a[c][0],   // backprop for: y_vals_a[c][0] = sum_negative
              sum_positive_grad = 0.0;
          for (int i = K - 1; i >= 0; i--) {
            // backprop for: sum_negative -= params_a[c][K - i] * scale;
            params_grad_a[c][K - i] -= sum_negative_grad * scale;
            // backprop for: sum_positive += params_a[c][1 + K + i] * scale;
            params_grad_a[c][1 + K + i] += sum_positive_grad * scale;
            // .. and the contributions to scale_grad for the 2 expressions above..
            scale_grad += (sum_positive_grad * params_a[c][1 + K + i] -
                           sum_negative_grad * params_a[c][K - i]);
            // backprop for:  y_vals_a[c][K - i] = sum_negative
            sum_negative_grad += y_vals_grad_a[c][K - i];
            // backprop for: y_vals_a[c][K + i] = sum_positive
            sum_positive_grad += y_vals_grad_a[c][K + i];
          }
          // Backprop for: scale = exp(params_a[c][0]),
          params_grad_a[c][0] += scale * scale_grad;
        }
      }));
  return std::vector<torch::Tensor>({input_grad, params_grad});
}
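
// Example usage (a minimal sketch, assuming these functions are called
// directly from C++; from Python one would go through the bindings below):
//
//   torch::Tensor input = torch::randn({2, 3, 5});   // (B, C, T)
//   torch::Tensor params = torch::randn({3, 5});     // (C, N + 1), N = 4
//   torch::Tensor output = learned_nonlin_cpu(input, params);
//   auto grads = learned_nonlin_backward_cpu(input, params,
//                                            torch::ones_like(output));
//   torch::Tensor input_grad = grads[0], params_grad = grads[1];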




PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("learned_nonlin_cpu", &learned_nonlin_cpu, "Learned nonlinearity forward function (CPU)");
  m.def("learned_nonlin_backward_cpu", &learned_nonlin_backward_cpu, "Learned nonlinearity backward function (CPU)");
}