Commit c33d0d64 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Add comments to explain inv-variance in batchnorm forward and backward

parent d2000114
...@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) = welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]); type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
......
...@@ -359,7 +359,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -359,7 +359,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
welford_count_thread_buf(I)); welford_count_thread_buf(I));
}); });
// calculate inv-variance from variance, stored in place // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) = welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon); type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
......
...@@ -410,6 +410,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -410,6 +410,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
}); });
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
inv_var_thread_buf(I) = inv_var_thread_buf(I) =
type_convert<AccDataType>(1.0) / sqrt(var_thread_buf[I] + epsilon); type_convert<AccDataType>(1.0) / sqrt(var_thread_buf[I] + epsilon);
......
...@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford ...@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) = var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]); type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment