"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "97f133d2b302de5287421c52ead628cae48c3b82"
Commit d5728dd3 authored by Qianfeng Zhang
Browse files

Improve the expression calculation for performance

parent de6aad06
...@@ -508,6 +508,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -508,6 +508,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
dy_thread_buf); dy_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier = type_convert<AccDataType>(1.0) /
type_convert<AccDataType>(reduce_size) *
inv_var_thread_buf[iM] * scale_thread_buf[iM];
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
...@@ -518,8 +522,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -518,8 +522,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
AccDataType tmpVal = norm_x * dscale_thread_buf[iM]; AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) = dx_thread_buf(Number<offset>{}) =
type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size) * multiplier *
inv_var_thread_buf[iM] * scale_thread_buf[iM] *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] - (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
dbias_thread_buf[iM] - tmpVal); dbias_thread_buf[iM] - tmpVal);
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment