syncbn_kernel.cu 17.5 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
#include <torch/extension.h>
Hang Zhang's avatar
Hang Zhang committed
2
3
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
Hang Zhang's avatar
Hang Zhang committed
4
#include <vector>
Hang Zhang's avatar
Hang Zhang committed
5
6
7
8
9
10
11
12
13

#include "common.h"
#include "device_tensor.h"

namespace {

// Pointwise functor for reduce(): for element (batch, plane, n) it yields the
// pair (dy, dy * (t - offset)) where dy = gradOutput[batch][plane][n] and
// t = output[batch][plane][n]. Reduced over a plane this gives
// (sum dy, sum dy * (t - offset)).
// Callers in this file construct it either with (beta[c], forward output) or
// with (mean[c], forward input), so `beta` is really a generic offset.
template <typename DType, typename Acctype, typename DeviceTensor3>
struct GradOp {
  __device__ GradOp(Acctype offset, const DeviceTensor3 values, const DeviceTensor3 grads)
    : beta(offset), output(values), gradOutput(grads) {}

  __device__ __forceinline__ Float2<DType, Acctype> operator()(int batch, int plane, int n) {
    const DType dy = gradOutput[batch][plane][n];
    // Centre the value in accumulator precision, then narrow back to DType.
    const DType centred =
        ScalarConvert<Acctype, DType>::to(output[batch][plane][n] - beta);
    return Float2<DType, Acctype>(dy, dy * centred);
  }

  // Offset subtracted from every `output` element (beta or mean, see above).
  const Acctype beta;
  // Tensor whose centred values are multiplied with the gradient.
  const DeviceTensor3 output;
  // Upstream gradient.
  const DeviceTensor3 gradOutput;
};

// Pointwise functor for reduce(): yields (x, x * x) for
// x = input[batch][plane][n], so a plane-wide reduction produces
// (sum x, sum x^2).
// (The former `DType mean;` member was never initialised nor read and has
// been dropped.)
template <typename DType, typename Acctype>
struct SumOp {
  __device__ SumOp(DeviceTensor<DType, 3> i) : input(i){}
  __device__ __forceinline__ Float2<DType, Acctype> operator()(int batch, int plane, int n) {
    DType g = input[batch][plane][n];
    return Float2<DType, Acctype>(g, g * g);
  }
  DeviceTensor<DType, 3> input;
};

// Sum across (batch, x/y/z) applying Op() pointwise.
//
// Block-wide reduction over one plane (index `plane` of dimension 1) of a 3-D
// tensor. Each thread accumulates op(batch, plane, x) for every batch and for
// x = threadIdx.x, threadIdx.x + blockDim.x, ... . Per-thread partials are
// combined with warpSum(); lane 0 of each warp stages its warp's value in
// shared memory; warp 0 then folds those values. The total is broadcast
// through shared[0], so every thread returns the same value.
//
// Preconditions (not checked here):
//  * every thread of the block calls this (it contains __syncthreads());
//  * at most 32 warps per block (the staging buffer has 32 slots);
//  * presumably blockDim.x is a multiple of WARP_SIZE (and >= WARP_SIZE),
//    since warpSum() is applied to whole warps -- TODO confirm against
//    warpSum in common.h.
template<typename T, typename Op, typename DeviceTensor3>
__device__ T reduce(Op op, DeviceTensor3 tensor, int plane) {
  T sum = (T)0;
  for (int batch = 0; batch < tensor.getSize(0); ++batch) {
    for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }

  // sum over NumThreads within a warp
  sum = warpSum(sum);

  // 'transpose', and reduce within warp again
  __shared__ T shared[32];
  // Make sure no thread is still reading shared[0] from an earlier call of
  // this instantiation before the slots are overwritten.
  __syncthreads();
  // Number of warps that publish a partial sum. Rounding up guarantees that a
  // trailing partial warp's slot is not also zeroed below. With floor division,
  // two threads could write the same slot, and for blockDim.x < WARP_SIZE
  // thread 0 would overwrite its own sum with zero.
  const unsigned int numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
  if (threadIdx.x % WARP_SIZE == 0) {
    shared[threadIdx.x / WARP_SIZE] = sum;
  }
  if (threadIdx.x >= numWarps && threadIdx.x < WARP_SIZE) {
    // zero out the unused entries so warp 0 can reduce all 32 slots
    shared[threadIdx.x] = (T)0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks up the block total from shared[0].
  return shared[0];
}

// Non-in-place batch-norm forward:
//   output[b][c][x] = gamma[c] * (input[b][c][x] - mean[c]) / std[c] + beta[c]
// Launch: one block per index c of dimension 1 (c == blockIdx.x). Each thread
// walks every batch and strides over dimension 2 with step blockDim.x.
template <typename DType>
__global__ void BatchNorm_Forward_kernel (
  DeviceTensor<DType, 3> output,
  DeviceTensor<DType, 3> input,
  DeviceTensor<DType, 1> mean,
  DeviceTensor<DType, 1> std,
  DeviceTensor<DType, 1> gamma,
  DeviceTensor<DType, 1> beta) {
  const int channel = blockIdx.x;
  // Per-channel parameters are identical for every element this block touches.
  const DType scaleG = gamma[channel];
  const DType centre = mean[channel];
  const DType spread = std[channel];
  const DType shift = beta[channel];
  const auto numBatches = input.getSize(0);
  const auto numPoints = input.getSize(2);
  for (int n = 0; n < numBatches; ++n) {
    for (int i = threadIdx.x; i < numPoints; i += blockDim.x) {
      const DType value = input[n][channel][i];
      output[n][channel][i] = scaleG * (value - centre) / spread + shift;
    }
  }
}

Hang Zhang's avatar
Hang Zhang committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// In-place batch-norm forward: each element of `input` is overwritten with
//   gamma[c] * (input[b][c][x] - mean[c]) / std[c] + beta[c]
// Launch: one block per index c of dimension 1 (c == blockIdx.x). Each thread
// walks every batch and strides over dimension 2 with step blockDim.x, so each
// element is read and then written by the same thread.
template <typename DType>
__global__ void BatchNorm_Forward_Inp_kernel (
  DeviceTensor<DType, 3> input,
  DeviceTensor<DType, 1> mean,
  DeviceTensor<DType, 1> std,
  DeviceTensor<DType, 1> gamma,
  DeviceTensor<DType, 1> beta) {
  const int channel = blockIdx.x;
  // Per-channel parameters, read once per thread.
  const DType scaleG = gamma[channel];
  const DType centre = mean[channel];
  const DType spread = std[channel];
  const DType shift = beta[channel];
  const auto numBatches = input.getSize(0);
  const auto numPoints = input.getSize(2);
  for (int n = 0; n < numBatches; ++n) {
    for (int i = threadIdx.x; i < numPoints; i += blockDim.x) {
      const DType value = input[n][channel][i];
      input[n][channel][i] = scaleG * (value - centre) / spread + shift;
    }
  }
}

// Backward pass of the in-place batch norm, one channel per block.
//
// Only the forward output y = gamma * (x - mean) / std + beta is available.
// GradOp is therefore built with beta[c], and reduce() returns
//   gradOutputSum = sum(dy)
//   dotP          = sum(dy * (y - beta)) = gamma * sum(dy * xhat),
// where xhat = (x - mean) / std. The host wrapper computes
// std = sqrt(E[x^2] - mean^2 + eps).
//
// Results:
//   gradinput[b][c][x] = dy * gamma / std  (mean/std held fixed)
//   gradEx[c]          = d loss / d mean, including the path through std
//   gradExs[c]         = d loss / d E[x^2], through std
//   gradgamma[c]      += dotP / gamma      (undefined for gamma == 0)
//   gradbeta[c]       += sum(dy)
// Only thread 0 writes the per-channel scalars. All threads of the block must
// run this kernel to completion because reduce() synchronizes the block.
template <typename DType>
__global__ void BatchNorm_Backward_Inp_kernel (
    DeviceTensor<DType, 3> gradoutput,
    DeviceTensor<DType, 3> output,
    DeviceTensor<DType, 3> gradinput,
    DeviceTensor<DType, 1> gradgamma,
    DeviceTensor<DType, 1> gradbeta,
    DeviceTensor<DType, 1> mean,
    DeviceTensor<DType, 1> std,
    DeviceTensor<DType, 1> gamma,
    DeviceTensor<DType, 1> beta,
    DeviceTensor<DType, 1> gradEx, 
    DeviceTensor<DType, 1> gradExs) {
  /* one block per channel */
  const int c = blockIdx.x;
  /* block-wide reduction: (sum dy, sum dy * (y - beta)) */
  GradOp<DType, DType, DeviceTensor<DType, 3>> g(beta[c], output, gradoutput);
  Float2<DType, DType> res = reduce<Float2<DType, DType>,
    GradOp<DType, DType, DeviceTensor<DType, 3>>,
    DeviceTensor<DType, 3>>(g, gradoutput, c);
  const DType gradOutputSum = res.v1;
  const DType dotP = res.v2;
  const DType invstd = DType(1.0) / std[c];
  const DType gradScale = invstd * gamma[c];
  if (threadIdx.x == 0) {
    gradEx[c] = - gradOutputSum * gradScale + mean[c] * invstd * invstd * dotP;
    // DType(0.5) rather than 0.5: a bare 0.5 is a double literal and would
    // promote the float instantiation to double arithmetic.
    gradExs[c]  = - DType(0.5) * invstd * invstd * dotP;
  }
  if (gradinput.numElements() > 0) {
    for (int batch = 0; batch < gradoutput.getSize(0); ++batch) {
      for (int x = threadIdx.x; x < gradoutput.getSize(2); x += blockDim.x) {
        gradinput[batch][c][x] = gradoutput[batch][c][x] * gradScale;
      }
    }
  }
  if (gradgamma.numElements() > 0) {
    if (threadIdx.x == 0) {
      // dotP carries a factor gamma (see header), divide it back out.
      gradgamma[c] += dotP / gamma[c];
    }
  }
  if (gradbeta.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradbeta[c] += gradOutputSum;
    }
  }
}

Hang Zhang's avatar
Hang Zhang committed
157
158
159
160
161
162
163
164
165
166
167
// Backward pass of the non-in-place batch norm, one channel per block.
//
// With y = gamma * (x - mean) / std + beta, GradOp is built from mean[c] and
// the saved input, and reduce() returns
//   gradOutputSum = sum(dy),   dotP = sum(dy * (x - mean)).
// The host wrapper computes std = sqrt(E[x^2] - mean^2 + eps).
//
// Results:
//   gradinput[b][c][x] = dy * gamma / std  (mean/std held fixed)
//   gradEx[c]          = d loss / d mean, including the path through std
//   gradExs[c]         = d loss / d E[x^2], through std
//   gradgamma[c]      += dotP / std
//   gradbeta[c]       += sum(dy)
// Only thread 0 writes the per-channel scalars. All threads of the block must
// run this kernel because reduce() synchronizes the block.
// `beta` is accepted for signature symmetry with the in-place kernel but is
// not read here.
template <typename DType>
__global__ void BatchNorm_Backward_kernel (
    DeviceTensor<DType, 3> gradoutput,
    DeviceTensor<DType, 3> input,
    DeviceTensor<DType, 3> gradinput,
    DeviceTensor<DType, 1> gradgamma,
    DeviceTensor<DType, 1> gradbeta,
    DeviceTensor<DType, 1> mean,
    DeviceTensor<DType, 1> std,
    DeviceTensor<DType, 1> gamma,
    DeviceTensor<DType, 1> beta,
    DeviceTensor<DType, 1> gradEx, 
    DeviceTensor<DType, 1> gradExs) {
  /* one block per channel */
  const int c = blockIdx.x;
  /* block-wide reduction: (sum dy, sum dy * (x - mean)) */
  GradOp<DType, DType, DeviceTensor<DType, 3>> g(mean[c], input, gradoutput);
  Float2<DType, DType> res = reduce<Float2<DType, DType>,
    GradOp<DType, DType, DeviceTensor<DType, 3>>,
    DeviceTensor<DType, 3>>(g, gradoutput, c);
  const DType gradOutputSum = res.v1;
  const DType dotP = res.v2;
  const DType invstd = DType(1.0) / std[c];
  const DType gradScale = invstd * gamma[c];
  if (threadIdx.x == 0) {
    gradEx[c] = - gradOutputSum * gradScale + mean[c] * invstd * invstd * dotP * gradScale;
    // DType(0.5) rather than 0.5: a bare 0.5 is a double literal and would
    // promote the float instantiation to double arithmetic.
    gradExs[c]  = - DType(0.5) * invstd * invstd * dotP * gradScale;
  }
  if (gradinput.numElements() > 0) {
    for (int batch = 0; batch < gradoutput.getSize(0); ++batch) {
      for (int x = threadIdx.x; x < gradoutput.getSize(2); x += blockDim.x) {
        gradinput[batch][c][x] = gradoutput[batch][c][x] * gradScale;
      }
    }
  }
  if (gradgamma.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradgamma[c] += dotP * invstd;
    }
  }
  if (gradbeta.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradbeta[c] += gradOutputSum;
    }
  }
}


// Per-channel raw moments, one channel per block (c == blockIdx.x):
//   ex[c]  = norm * sum_{b,x} input[b][c][x]
//   exs[c] = norm * sum_{b,x} input[b][c][x]^2
// The host wrapper passes norm = 1 / (size(0) * size(2)), which makes these
// E[x] and E[x^2]. All threads must run this kernel because reduce()
// synchronizes the block; thread 0 writes the results.
template <typename DType>
__global__ void Expectation_Forward_kernel (
    DeviceTensor<DType, 3> input,
    DeviceTensor<DType, 1> ex,
    DeviceTensor<DType, 1> exs,
    DType norm) {
  typedef SumOp<DType, DType> MomentOp;
  typedef Float2<DType, DType> MomentPair;
  const int plane = blockIdx.x;
  MomentOp op(input);
  const MomentPair moments =
      reduce<MomentPair, MomentOp, DeviceTensor<DType, 3>>(op, input, plane);
  if (threadIdx.x == 0) {
    ex[plane] = moments.v1 * norm;
    exs[plane] = moments.v2 * norm;
  }
}

// Gradient of the moments from Expectation_Forward_kernel w.r.t. the input:
//   gradInput[b][c][x] = gradEx[c] * norm + 2 * gradExs[c] * input[b][c][x] * norm
// One channel per block (c == blockIdx.x). Every element is overwritten.
template <typename DType>
__global__ void Expectation_Backward_kernel (
  DeviceTensor<DType, 3> gradInput,
  DeviceTensor<DType, 3> input,
  DeviceTensor<DType, 1> gradEx,
  DeviceTensor<DType, 1> gradExs,
  DType norm) {
  const int c = blockIdx.x;
  // Per-channel terms of the affine-in-x gradient.
  const DType linearTerm = gradEx[c] * norm;
  const DType quadCoeff = 2 * gradExs[c];
  const auto numBatches = gradInput.getSize(0);
  const auto numPoints = gradInput.getSize(2);
  for (int n = 0; n < numBatches; ++n) {
    for (int i = threadIdx.x; i < numPoints; i += blockDim.x) {
      gradInput[n][c][i] = linearTerm + quadCoeff * input[n][c][i] * norm;
    }
  }
}

// In-place variant of Expectation_Backward_kernel. The original input is not
// available, so it is reconstructed from the forward output as
//   x = (output - beta[c]) / gamma[c] * std[c] + mean[c]
// and the moment gradient gradEx[c] * norm + 2 * gradExs[c] * x * norm is
// ADDED to gradInput (accumulated, not overwritten).
// One channel per block (c == blockIdx.x).
template <typename DType>
__global__ void Expectation_Backward_Inp_kernel (
  DeviceTensor<DType, 3> gradInput,
  DeviceTensor<DType, 3> output,
  DeviceTensor<DType, 1> gradEx,
  DeviceTensor<DType, 1> gradExs,
  DeviceTensor<DType, 1> mean,
  DeviceTensor<DType, 1> std,
  DeviceTensor<DType, 1> gamma,
  DeviceTensor<DType, 1> beta,
  DType norm) {
  const int c = blockIdx.x;
  // Per-channel parameters of the inverse affine transform.
  const DType centre = mean[c];
  const DType spread = std[c];
  const DType scaleG = gamma[c];
  const DType shift = beta[c];
  // Per-channel terms of the gradient.
  const DType linearTerm = gradEx[c] * norm;
  const DType quadCoeff = 2 * gradExs[c];
  const auto numBatches = gradInput.getSize(0);
  const auto numPoints = gradInput.getSize(2);
  for (int n = 0; n < numBatches; ++n) {
    for (int i = threadIdx.x; i < numPoints; i += blockDim.x) {
      // Undo y = gamma * (x - mean) / std + beta.
      const DType recovered = (output[n][c][i] - shift) / scaleG * spread + centre;
      gradInput[n][c][i] += linearTerm + quadCoeff * recovered * norm;
    }
  }
}

Hang Zhang's avatar
Hang Zhang committed
263
} // namespace
Hang Zhang's avatar
Hang Zhang committed
264
265
266

// Host launcher for the non-in-place batch-norm forward:
//   output = gamma * (input - ex) / sqrt(exs - ex^2 + eps) + beta
// with ex_, exs_, gamma_ and beta_ indexed per element of input_'s dimension 1.
// input_ is accessed as a 3-D tensor: one block per index of dimension 1, and
// threads stride over dimension 2. Runs on the current CUDA stream and returns
// a newly allocated tensor shaped like input_.
at::Tensor BatchNorm_Forward_CUDA(
    const at::Tensor input_, 
    const at::Tensor ex_,
    const at::Tensor exs_,
    const at::Tensor gamma_,
    const at::Tensor beta_,
    float eps) {
  // The kernel writes every (b, c, x) element, so zero-filling is unnecessary.
  auto output_ = at::empty_like(input_);
  auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input_.size(1));
  dim3 threads(getNumThreads(input_.size(2)));
  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> output = devicetensor<scalar_t, 3>(output_);
    DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 1> ex = devicetensor<scalar_t, 1>(ex_);
    DeviceTensor<scalar_t, 1> std = devicetensor<scalar_t, 1>(std_);
    DeviceTensor<scalar_t, 1> gamma = devicetensor<scalar_t, 1>(gamma_);
    DeviceTensor<scalar_t, 1> beta = devicetensor<scalar_t, 1>(beta_);
    /* kernel function */
    BatchNorm_Forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        output, input, ex, std, gamma, beta);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return output_;
}

Hang Zhang's avatar
Hang Zhang committed
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
// Host launcher for the in-place batch-norm forward: overwrites input_ with
//   gamma * (input - ex) / sqrt(exs - ex^2 + eps) + beta
// (parameters indexed per element of dimension 1) and returns input_.
// One block per index of dimension 1 on the current CUDA stream.
at::Tensor BatchNorm_Forward_Inp_CUDA(
    const at::Tensor input_, 
    const at::Tensor ex_,
    const at::Tensor exs_,
    const at::Tensor gamma_,
    const at::Tensor beta_,
    float eps) {
  auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input_.size(1));
  dim3 threads(getNumThreads(input_.size(2)));
  // Dispatch label names this function (it was copy-pasted from the
  // non-in-place launcher), so dtype errors point at the right entry point.
  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_Inp_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 1> ex = devicetensor<scalar_t, 1>(ex_);
    DeviceTensor<scalar_t, 1> std = devicetensor<scalar_t, 1>(std_);
    DeviceTensor<scalar_t, 1> gamma = devicetensor<scalar_t, 1>(gamma_);
    DeviceTensor<scalar_t, 1> beta = devicetensor<scalar_t, 1>(beta_);
    /* kernel function */
    BatchNorm_Forward_Inp_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        input, ex, std, gamma, beta);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return input_;
}


// Host launcher for the backward pass of the in-place batch norm.
// It works from the forward output (the input was overwritten) plus the saved
// per-channel statistics ex_ / exs_. It uses std = sqrt(exs - ex^2 + eps).
// Returns {gradinput, gradEx, gradExs, gradgamma, gradbeta}. All five are
// zero-initialised before the kernel (one block per channel, current stream)
// fills them.
std::vector<at::Tensor> BatchNorm_Inp_Backward_CUDA(
    const at::Tensor gradoutput_,
    const at::Tensor output_,
    const at::Tensor ex_, 
    const at::Tensor exs_,
    const at::Tensor gamma_,
    const at::Tensor beta_,
    float eps) {
  auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();

  // Result buffers; gradgamma/gradbeta are accumulated with += in the kernel.
  auto gradinput_ = at::zeros_like(output_);
  auto gradEx_ = at::zeros_like(ex_);
  auto gradExs_ = at::zeros_like(std_);
  auto gradgamma_ = at::zeros_like(gamma_);
  auto gradbeta_ = at::zeros_like(beta_);

  // Launch geometry: one block per channel, threads over the last dimension.
  dim3 grid(output_.size(1));
  dim3 block(getNumThreads(output_.size(2)));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(output_.type(), "BatchNorm_Inp_Backward_CUDA", ([&] {
    // 3-D views
    DeviceTensor<scalar_t, 3> dOut = devicetensor<scalar_t, 3>(gradoutput_);
    DeviceTensor<scalar_t, 3> yOut = devicetensor<scalar_t, 3>(output_);
    DeviceTensor<scalar_t, 3> dIn = devicetensor<scalar_t, 3>(gradinput_);
    // per-channel parameters / statistics
    DeviceTensor<scalar_t, 1> meanT = devicetensor<scalar_t, 1>(ex_);
    DeviceTensor<scalar_t, 1> stdT = devicetensor<scalar_t, 1>(std_);
    DeviceTensor<scalar_t, 1> gammaT = devicetensor<scalar_t, 1>(gamma_);
    DeviceTensor<scalar_t, 1> betaT = devicetensor<scalar_t, 1>(beta_);
    // per-channel gradients
    DeviceTensor<scalar_t, 1> dGamma = devicetensor<scalar_t, 1>(gradgamma_);
    DeviceTensor<scalar_t, 1> dBeta = devicetensor<scalar_t, 1>(gradbeta_);
    DeviceTensor<scalar_t, 1> dEx = devicetensor<scalar_t, 1>(gradEx_);
    DeviceTensor<scalar_t, 1> dExs = devicetensor<scalar_t, 1>(gradExs_);
    BatchNorm_Backward_Inp_kernel<scalar_t><<<grid, block, 0, stream>>>(
        dOut, yOut, dIn, dGamma, dBeta, meanT, stdT, gammaT, betaT, dEx, dExs);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {gradinput_, gradEx_, gradExs_, gradgamma_, gradbeta_};
}


Hang Zhang's avatar
Hang Zhang committed
363
364
365
// Host launcher for the backward pass of the non-in-place batch norm.
// It uses the saved forward input and the per-channel statistics ex_ / exs_,
// with std = sqrt(exs - ex^2 + eps).
// Returns {gradinput, gradEx, gradExs, gradgamma, gradbeta}. All five are
// zero-initialised before the kernel (one block per channel, current stream)
// fills them.
std::vector<at::Tensor> BatchNorm_Backward_CUDA(
    const at::Tensor gradoutput_,
    const at::Tensor input_,
    const at::Tensor ex_, 
    const at::Tensor exs_,
    const at::Tensor gamma_,
    const at::Tensor beta_,
    float eps) {
  /* outputs*/
  auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
  auto gradinput_ = at::zeros_like(input_);
  auto gradgamma_ = at::zeros_like(gamma_);
  auto gradbeta_ = at::zeros_like(beta_);
  auto gradEx_ = at::zeros_like(ex_);
  auto gradExs_ = at::zeros_like(std_);
  /* cuda utils*/
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input_.size(1));
  dim3 threads(getNumThreads(input_.size(2)));
  // Dispatch label names this function; it previously reused the in-place
  // launcher's name, which made dtype errors point at the wrong entry point.
  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> gradoutput = devicetensor<scalar_t, 3>(gradoutput_);
    DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 3> gradinput = devicetensor<scalar_t, 3>(gradinput_);
    DeviceTensor<scalar_t, 1> gradgamma = devicetensor<scalar_t, 1>(gradgamma_);
    DeviceTensor<scalar_t, 1> gradbeta = devicetensor<scalar_t, 1>(gradbeta_);
    DeviceTensor<scalar_t, 1> ex = devicetensor<scalar_t, 1>(ex_);
    DeviceTensor<scalar_t, 1> std = devicetensor<scalar_t, 1>(std_);
    DeviceTensor<scalar_t, 1> gamma = devicetensor<scalar_t, 1>(gamma_);
    DeviceTensor<scalar_t, 1> beta = devicetensor<scalar_t, 1>(beta_);
    DeviceTensor<scalar_t, 1> gradEx = devicetensor<scalar_t, 1>(gradEx_);
    DeviceTensor<scalar_t, 1> gradExs = devicetensor<scalar_t, 1>(gradExs_);
    /* kernel function */
    BatchNorm_Backward_kernel<scalar_t>
      <<<blocks, threads, 0, stream>>>(
      gradoutput, input, gradinput, gradgamma, gradbeta, ex, std, 
      gamma, beta, gradEx, gradExs);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {gradinput_, gradEx_, gradExs_, gradgamma_, gradbeta_};
}

Hang Zhang's avatar
Hang Zhang committed
405
// Host launcher for the per-channel raw moments of input_.
// Returns {ex, exs}: length size(1) tensors holding the mean of x and of x^2
// over dimensions 0 and 2 (norm = 1 / (size(0) * size(2))).
// One block per channel on the current CUDA stream.
std::vector<at::Tensor> Expectation_Forward_CUDA(
    const at::Tensor input_) {
  const auto channels = input_.size(1);
  // Result buffers, same dtype/device as the input.
  auto ex_ = torch::zeros({channels}, input_.options());
  auto exs_ = torch::zeros({channels}, input_.options());

  dim3 grid(channels);
  dim3 block(getNumThreads(input_.size(2)));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_forward_CUDA", ([&] {
    scalar_t norm = scalar_t(1) / (input_.size(0) * input_.size(2));
    DeviceTensor<scalar_t, 3> src = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 1> firstMoment = devicetensor<scalar_t, 1>(ex_);
    DeviceTensor<scalar_t, 1> secondMoment = devicetensor<scalar_t, 1>(exs_);
    Expectation_Forward_kernel<scalar_t><<<grid, block, 0, stream>>>(
        src, firstMoment, secondMoment, norm);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {ex_, exs_};
}

Hang Zhang's avatar
Hang Zhang committed
428
// Host launcher for the gradient of Expectation_Forward_CUDA w.r.t. input_.
// Given upstream gradients of ex and exs, returns a new tensor shaped like
// input_ holding gradEx * norm + 2 * gradExs * x * norm per element
// (norm = 1 / (size(0) * size(2))). One block per channel, current stream.
at::Tensor Expectation_Backward_CUDA(
    const at::Tensor input_,
    const at::Tensor gradEx_,
    const at::Tensor gradExs_) {
  at::Tensor gradInput_ = at::zeros_like(input_);

  dim3 grid(input_.size(1));
  dim3 block(getNumThreads(input_.size(2)));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_Backward_CUDA", ([&] {
    scalar_t norm = scalar_t(1) / (input_.size(0) * input_.size(2));
    DeviceTensor<scalar_t, 3> dst = devicetensor<scalar_t, 3>(gradInput_);
    DeviceTensor<scalar_t, 3> src = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 1> dFirst = devicetensor<scalar_t, 1>(gradEx_);
    DeviceTensor<scalar_t, 1> dSecond = devicetensor<scalar_t, 1>(gradExs_);
    Expectation_Backward_kernel<scalar_t><<<grid, block, 0, stream>>>(
        dst, src, dFirst, dSecond, norm);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return gradInput_;
}

// Host launcher for the moment gradient on the in-place path.
// The original input was overwritten by the in-place forward, so the kernel
// reconstructs it from output_ as (y - beta) / gamma * std + ex, using
// std = sqrt(exs - ex^2 + eps). It then ADDS
// gradEx * norm + 2 * gradExs * x * norm into gradInput_, which is modified
// in place and also returned (norm = 1 / (size(0) * size(2))).
// One block per channel on the current CUDA stream.
at::Tensor Expectation_Inp_Backward_CUDA(
    const at::Tensor gradInput_,
    const at::Tensor output_,
    const at::Tensor gradEx_,
    const at::Tensor gradExs_,
    const at::Tensor ex_, 
    const at::Tensor exs_,
    const at::Tensor gamma_,
    const at::Tensor beta_,
    float eps) {
  auto std_ = (exs_ - ex_ * ex_ + eps).sqrt();
  /* cuda utils*/
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(output_.size(1));
  dim3 threads(getNumThreads(output_.size(2)));
  // Dispatch label names this function; it previously duplicated the label
  // used by Expectation_Backward_CUDA.
  AT_DISPATCH_FLOATING_TYPES(output_.type(), "Expectation_Inp_Backward_CUDA", ([&] {
    scalar_t norm = scalar_t(1) / (output_.size(0) * output_.size(2));
    /* Device tensors */
    DeviceTensor<scalar_t, 3> gradInput = devicetensor<scalar_t, 3>(gradInput_);
    // Forward output (the input itself was overwritten in place).
    DeviceTensor<scalar_t, 3> output = devicetensor<scalar_t, 3>(output_);
    DeviceTensor<scalar_t, 1> gradEx = devicetensor<scalar_t, 1>(gradEx_);
    DeviceTensor<scalar_t, 1> gradExs = devicetensor<scalar_t, 1>(gradExs_);
    DeviceTensor<scalar_t, 1> ex = devicetensor<scalar_t, 1>(ex_);
    DeviceTensor<scalar_t, 1> std = devicetensor<scalar_t, 1>(std_);
    DeviceTensor<scalar_t, 1> gamma = devicetensor<scalar_t, 1>(gamma_);
    DeviceTensor<scalar_t, 1> beta = devicetensor<scalar_t, 1>(beta_);
    /* kernel function */
    Expectation_Backward_Inp_kernel<scalar_t>
      <<<blocks, threads, 0, stream>>>(gradInput, output, gradEx, gradExs,
          ex, std, gamma, beta, norm);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return gradInput_;
}