Commit 223a47e9 authored by jiej's avatar jiej
Browse files

allowing syncBN to run with affine = False

parent aed3086a
......@@ -21,8 +21,8 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
......@@ -32,7 +32,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
const at::optional<at::Tensor> weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
......@@ -41,7 +41,7 @@ at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
......@@ -57,8 +57,8 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
......@@ -68,7 +68,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
const at::optional<at::Tensor> weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
......@@ -78,7 +78,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment