#include <torch/extension.h>
#include <ATen/ATen.h>

#include <vector>

// returns {mean,biased_var}
// implemented using Welford's algorithm
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
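
// For reference, a minimal scalar Welford update for one channel (an
// illustrative sketch of the single-pass algorithm the kernel parallelizes,
// not the CUDA implementation itself; fp32 accumulation assumed):
//
//   float mean = 0.f, m2 = 0.f;
//   for (int64_t n = 0; n < count; ++n) {
//     float delta = x[n] - mean;
//     mean += delta / (n + 1);
//     m2 += delta * (x[n] - mean);    // running sum of squared deviations
//   }
//   float biased_var = m2 / count;    // matches the returned biased_var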

// reduces array of mean/var across processes
// returns global {mean,inv_std,biased_var}
// implemented using Welford's algorithm
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased_feature_nodes,
                                              int numel,
                                              const float eps);
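
// Sketch of merging two partial (mean, biased_var, count) results, following
// the standard parallel-variance update (Chan et al.); this is what the kernel
// applies across the per-node arrays (variable names here are illustrative):
//
//   float delta   = mean_b - mean_a;
//   int64_t n_ab  = n_a + n_b;
//   float mean_ab = mean_a + delta * n_b / n_ab;
//   float m2_ab   = var_a * n_a + var_b * n_b
//                 + delta * delta * n_a * n_b / n_ab;
//   float var_ab  = m2_ab / n_ab;             // combined biased variance
//   float inv_std = rsqrtf(var_ab + eps);     // the returned inv_std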

// elementwise BN operation, returns output
// input/weight/shift must have identical data types;
// mean/inv_std use the promoted data type (dtype == fp16 ? fp32 : dtype)
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                  const at::Tensor mean,
                                  const at::Tensor inv_std,
                                  const at::Tensor weight,
                                  const at::Tensor shift);
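
// Per element, the forward op evaluates (scalar sketch, per channel c):
//
//   output[i][c] = (input[i][c] - mean[c]) * inv_std[c] * weight[c] + shift[c];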

// backward BN reduction, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input must have identical data types;
// mean/inv_std use the promoted data type (dtype == fp16 ? fp32 : dtype)
// implemented using Kahan summation
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                       const at::Tensor input,
                                       const at::Tensor mean,
                                       const at::Tensor inv_std,
                                       const at::Tensor weight);
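
// Per channel c, the reduction computes (scalar sketch, n ranging over all
// reduced elements; consistent with the return values named above):
//
//   mean_dy[c]     = mean_n( grad_output[n][c] )
//   mean_dy_xmu[c] = mean_n( grad_output[n][c] * (input[n][c] - mean[c]) )
//   grad_weight[c] = sum_n ( grad_output[n][c] * (input[n][c] - mean[c]) ) * inv_std[c]
//   grad_bias[c]   = sum_n ( grad_output[n][c] )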

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision can be fp16 or fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                   const at::Tensor input,
                                   const at::Tensor mean,
                                   const at::Tensor inv_std,
                                   const at::Tensor weight,
                                   const at::Tensor mean_dy,
                                   const at::Tensor mean_dy_xmu);
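
// grad_input follows the standard BN backward identity (scalar sketch; with
// mean_dy and mean_dy_xmu as produced by reduce_bn above):
//
//   grad_input[n][c] = (grad_output[n][c] - mean_dy[c]
//                       - (input[n][c] - mean[c])
//                         * inv_std[c] * inv_std[c] * mean_dy_xmu[c])
//                      * inv_std[c] * weight[c];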

// returns {mean, biased_var}
// implemented using Welford's algorithm
// expects data in channels-last format (NHWC) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
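
// Layout note: in a channels-last [N, H, W, C] tensor, element (n, h, w, c)
// sits at flat index ((n * H + h) * W + w) * C + c, so c is the fastest-moving
// dimension and adjacent threads can reduce adjacent channels.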

// elementwise BN operation, returns output
// input/weight/shift must have identical data types;
// mean/inv_std use the promoted data type (dtype == fp16 ? fp32 : dtype)
// expects data in channels-last format (NHWC) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                         const at::Tensor mean,
                                         const at::Tensor inv_std,
                                         const at::Tensor weight,
                                         const at::Tensor shift);

// backward BN reduction, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input must have identical data types;
// mean/inv_std use the promoted data type (dtype == fp16 ? fp32 : dtype)
// expects data in channels-last format (NHWC) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
                                              const at::Tensor input,
                                              const at::Tensor mean,
                                              const at::Tensor inv_std,
                                              const at::Tensor weight);

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision can be fp16 or fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// expects data in channels-last format (NHWC) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                          const at::Tensor input,
                                          const at::Tensor mean,
                                          const at::Tensor inv_std,
                                          const at::Tensor weight,
                                          const at::Tensor mean_dy,
                                          const at::Tensor mean_dy_xmu);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
  m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
  m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
}