// syncbn.cpp
#include <torch/extension.h>
#include <ATen/ATen.h>

#include <vector>

// returns {mean,biased_var}
// implemented using Welford's online algorithm
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
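
// Illustrative sketch (hypothetical helper, not part of the extension): the
// single-pass Welford recurrence the reduction above refers to. The CUDA kernel
// parallelizes this per channel; only the scalar update step is shown here.
inline void welford_update_sketch(float x, int& count, float& mean, float& m2) {
  count += 1;
  float delta = x - mean;
  mean += delta / count;
  m2 += delta * (x - mean);  // biased_var = m2 / count after the last update
}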

// reduces the per-process arrays of mean/var across processes
// returns global {mean,inv_std,biased_var}
// implemented using Welford's algorithm
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased_feature_nodes,
                                              const at::Tensor numel,
                                              const float eps);
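
// Illustrative sketch (hypothetical helper, not part of the extension): one way
// to merge two partial {mean, biased_var, count} statistics, i.e. the parallel
// form of Welford's algorithm; the global inv_std would be 1 / sqrt(var + eps).
inline void welford_merge_sketch(float mean_a, float var_a, float n_a,
                                 float mean_b, float var_b, float n_b,
                                 float& mean_ab, float& var_ab) {
  float n = n_a + n_b;
  float delta = mean_b - mean_a;
  mean_ab = mean_a + delta * n_b / n;
  // recover the M2 accumulators from the biased variances, combine, renormalize
  float m2 = var_a * n_a + var_b * n_b + delta * delta * n_a * n_b / n;
  var_ab = m2 / n;
}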

// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                  const at::Tensor mean,
                                  const at::Tensor inv_std,
                                  const at::optional<at::Tensor> weight,
                                  const at::optional<at::Tensor> shift);
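
// Illustrative sketch (hypothetical scalar helper): the per-element affine
// normalization described above, y = (x - mean) * inv_std * weight + shift.
inline float bn_forward_elem_sketch(float x, float mean, float inv_std,
                                    float weight, float shift) {
  return (x - mean) * inv_std * weight + shift;
}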

// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using Kahan summation
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                       const at::Tensor input,
                                       const at::Tensor mean,
                                       const at::Tensor inv_std,
                                       const at::optional<at::Tensor> weight);
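
// Illustrative sketch (hypothetical helper): Kahan compensated summation, the
// technique named in the comment above. A running compensation term recovers
// the low-order bits lost when adding a small term to a large running sum.
inline void kahan_add_sketch(float value, float& sum, float& compensation) {
  float y = value - compensation;
  float t = sum + y;
  compensation = (t - sum) - y;  // algebraically zero; numerically the lost low-order part of y
  sum = t;
}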

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                   const at::Tensor input,
                                   const at::Tensor mean,
                                   const at::Tensor inv_std,
                                   const at::optional<at::Tensor> weight,
                                   const at::Tensor sum_dy,
                                   const at::Tensor sum_dy_xmu,
                                   const at::Tensor count);
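
// Illustrative sketch (hypothetical scalar helper): the standard batch-norm
// grad_input expression evaluated per element from the reduced sums,
//   dx = (dy - sum_dy/N - (x - mean) * inv_std^2 * sum_dy_xmu/N) * inv_std * weight,
// where N is the per-channel element count accumulated in `count`.
inline float bn_backward_elem_sketch(float dy, float x, float mean, float inv_std,
                                     float weight, float sum_dy, float sum_dy_xmu,
                                     float n) {
  return (dy - sum_dy / n - (x - mean) * inv_std * inv_std * sum_dy_xmu / n)
         * inv_std * weight;
}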

// returns {mean, biased_var}
// implemented using Welford's online algorithm
// expects data in channels-last (NHWC) format and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);

// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expects data in channels-last (NHWC) format and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                         const at::optional<at::Tensor> z,
                                         const at::Tensor mean,
                                         const at::Tensor inv_std,
                                         const at::optional<at::Tensor> weight,
                                         const at::optional<at::Tensor> shift,
                                         const bool fuse_relu);
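
// Illustrative sketch (hypothetical scalar helper, assuming z is a residual
// input added before the activation): the fused add + ReLU epilogue selected
// by fuse_relu on top of the usual affine normalization.
inline float bn_add_relu_elem_sketch(float x, float z, float mean, float inv_std,
                                     float weight, float shift, bool fuse_relu) {
  float y = (x - mean) * inv_std * weight + shift + z;
  return fuse_relu ? (y > 0.f ? y : 0.f) : y;
}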

// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expects data in channels-last (NHWC) format and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
                                              const at::Tensor input,
                                              const at::Tensor mean,
                                              const at::Tensor inv_std,
                                              const at::optional<at::Tensor> weight);

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
// expects data in channels-last (NHWC) format and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                          const at::Tensor input,
                                          const at::Tensor mean,
                                          const at::Tensor inv_std,
                                          const at::optional<at::Tensor> weight,
                                          const at::Tensor sum_dy,
                                          const at::Tensor sum_dy_xmu,
                                          const at::Tensor count);

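// elementwise backward of the fused ReLU path in batchnorm_forward_c_last, returns grad w.r.t. the BN output
// (the saved mean/inv_std/weight/shift/z presumably let the kernel recompute the activation to mask grad_output)
// expects data in channels-last (NHWC) format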
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
                                     const at::Tensor input,
                                     const at::optional<at::Tensor> z,
                                     const at::Tensor mean,
                                     const at::Tensor inv_std,
                                     const at::optional<at::Tensor> weight,
                                     const at::optional<at::Tensor> shift);


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
  m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
  m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
  m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}