Commit 3ac68bcc authored by Qianfeng Zhang
Browse files

Hoist the loop-invariant common expression out of the loop in reference_batchnorm_backward_nhwc_c

parent 658c55ff
......@@ -183,6 +183,10 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
arg.p_dscale_[offset_C] = type_convert<ScaleDataType>(dscale);
arg.p_dbias_[offset_C] = type_convert<BiasDataType>(dbias);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
AccDataType multiplier =
type_convert<AccDataType>(1.0f) / reduceSize * invVar * scale;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
for(index_t iN = 0; iN < arg.n_; iN++)
......@@ -201,14 +205,12 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = type_convert<AccDataType>(1.0f) / reduceSize * invVar *
scale * (reduceSize * dy - dbias - tmpVal);
AccDataType dx = multiplier * (reduceSize * dy - dbias - tmpVal);
arg.p_dx_[offset] = type_convert<XDataType>(dx);
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment