Commit 213187f6 authored by Qianfeng Zhang

Add/update the comments for batchnorm-backward kernels

parent a5d2bb15
@@ -140,6 +140,11 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+// clang-format off
+// The final two steps of multiblock BatchNorm backward
+// Step 1: second half of the reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
+// Step 2: compute dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance) element-wise
+// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k,
@@ -208,7 +213,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
-// Step 1: do final reduction for scale_diff and bias_diff and output
+// clang-format off
+// Step 1: do the final reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
+// clang-format on
auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
@@ -340,8 +347,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
bias_grid_desc_m,
dbias_global_buf);
-// Step 2: calculate dx = 1/N * invVar * scale * (N * dy - biasDiff - scaleDiff * (x - mean)
-// * invVar) and output
+// clang-format off
+// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
+// clang-format on
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
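For reference, the math these two steps implement can be written as a scalar sketch over a single normalization channel. This is plain C++ for illustration, not the kernel's API; the function and parameter names are hypothetical, while `reduce_size`, inv-variance, `dscale`, and `dbias` follow the comments above:

```cpp
#include <cstddef>

// Scalar sketch of the two steps: final reduction, then element-wise dx.
// All elements of one channel share mean, inv_variance, and scale.
void batchnorm_bwd_final(const float* x, const float* dy, std::size_t reduce_size,
                         float mean, float inv_variance, float scale,
                         float* dx, float& dscale, float& dbias)
{
    // Step 1: dbias = sum(dy), dscale = sum(dy * (x - mean) * inv_variance)
    dbias  = 0.f;
    dscale = 0.f;
    for(std::size_t i = 0; i < reduce_size; ++i)
    {
        dbias += dy[i];
        dscale += dy[i] * (x[i] - mean) * inv_variance;
    }

    // Step 2: dx = 1/reduce_size * inv_variance * scale *
    //              (reduce_size * dy - dbias - dscale * (x - mean) * inv_variance)
    for(std::size_t i = 0; i < reduce_size; ++i)
    {
        dx[i] = (1.f / reduce_size) * inv_variance * scale *
                (reduce_size * dy[i] - dbias - dscale * (x[i] - mean) * inv_variance);
    }
}
```

The grid-level kernel tiles this same arithmetic over M×K thread slices and performs the reduction blockwise rather than in the serial loops shown here.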
@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+// clang-format off
+// First half of the multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
+// clang-format on
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
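Welford's method computes mean and variance in a single pass; in the multiblock variant, each workgroup produces a partial (mean, M2, count) triple over its slice of the reduce dimension. A minimal scalar sketch of the per-element update, with hypothetical names and none of the kernel's tiling:

```cpp
// Partial Welford state produced by one workgroup over its slice.
struct WelfordPartial
{
    float mean  = 0.f; // running mean of the elements seen so far
    float m2    = 0.f; // running sum of squared deviations from the mean
    int   count = 0;   // number of elements accumulated
};

// One Welford update: fold element x into the running (mean, m2, count).
inline void welford_update(WelfordPartial& w, float x)
{
    w.count += 1;
    const float delta = x - w.mean;
    w.mean += delta / w.count;
    w.m2 += delta * (x - w.mean); // note: uses the already-updated mean
}
```

The combination of the per-workgroup partials into one result happens in the second-half kernel, sketched after the next section.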
@@ -145,6 +145,11 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+// clang-format off
+// Two of the steps of multiblock BatchNorm backward
+// Step 1: second half of the Welford method to calculate mean and variance, plus computing inv-variance = 1/sqrt(epsilon+variance)
+// Step 2: first half of the reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
+// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
@@ -196,7 +201,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf;
-// buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction
+// buffer of values of dy * (x-mean) * inv-variance, used as input of the blockwise reduction
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
tmp1_thread_buf;
@@ -226,8 +231,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
-// Step 1: load existing mean and inv-variance do final welford reduction on mean and
-// variance
+// clang-format off
+// Step 1: load the existing mean and inv-variance, or do the final Welford reduction on mean and variance and compute inv-variance = 1/sqrt(epsilon+variance)
+// clang-format on
if(haveSavedMeanInvVar)
{
@@ -451,7 +457,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
-// Step 2: do first-half reduction on dy and dy * (x-mean) * inv-variance
+// clang-format off
+// Step 2: first half of the reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
+// clang-format on
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
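The "second half of Welford" merges the per-workgroup partials into one mean/variance per channel using the standard parallel combination of Welford states, then forms inv-variance. A scalar sketch, with the `WelfordPartial` shape repeated from the sketch above and `welford_merge`/`inv_variance_from` as hypothetical names:

```cpp
#include <cmath>

// Same shape as the earlier sketch: partial Welford state.
struct WelfordPartial { float mean = 0.f; float m2 = 0.f; int count = 0; };

// Merge two partial Welford states into one (parallel combination).
inline WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;
    const float delta = b.mean - a.mean;
    out.mean = a.mean + delta * b.count / out.count;
    out.m2   = a.m2 + b.m2 + delta * delta * a.count * b.count / out.count;
    return out;
}

// After the final merge: biased variance, then inv-variance = 1/sqrt(epsilon + variance).
inline float inv_variance_from(const WelfordPartial& w, float epsilon)
{
    const float variance = w.m2 / w.count;
    return 1.f / std::sqrt(epsilon + variance);
}
```

Merging is associative, which is what lets each workgroup reduce its slice independently before the combination step.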
@@ -153,9 +153,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
// Blockwise BatchNorm Backward
// Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size
// Output: dx, dscale, dbias
-// Step 1: calculate to get mean and invVariance using welford method (if savedMean and savedInvVar not available)
-// Step 2: reduce on dy and dy * (x - mean) * invVariance to get dbias and dscale respectively
-// Step 3: calculate 1/reduce_size * invVariance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * invVariance)) to get dx elementwise-ly
+// Step 1: calculate mean and inv-variance using the Welford method (if savedMean/savedInvVar are not available), where inv-variance = 1/sqrt(epsilon+variance)
+// Step 2: reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
+// Step 3: compute dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance) element-wise
// clang-format on
__device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
@@ -344,6 +344,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
+// clang-format off
+// Step 1: calculate mean and inv-variance using the Welford method (if savedMean/savedInvVar are not available), where inv-variance = 1/sqrt(epsilon+variance)
+// clang-format on
if(haveSavedMeanInvVar)
{
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -423,6 +427,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
thread_k_cluster_id * KThreadSliceSize));
};
+// clang-format off
+// Step 2: reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
+// clang-format on
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dscale_thread_buf(I) = type_convert<AccDataType>(0);
dbias_thread_buf(I) = type_convert<AccDataType>(0);
@@ -482,6 +490,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
dbias_global_buf);
};
+// clang-format off
+// Step 3: compute dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance) element-wise
+// clang-format on
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_buf,
thread_buffer_desc_m,
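Putting the three steps together, a plain-C++ reference for the blockwise path over one channel may help; the function name and the `saved_mean`/`saved_inv_var` pointers (modelling the optional savedMean/savedInvVar inputs) are hypothetical:

```cpp
#include <cmath>
#include <cstddef>

// Scalar reference for the three steps above, one channel.
void batchnorm_bwd_blockwise_ref(const float* x, const float* dy, std::size_t reduce_size,
                                 float scale, float epsilon,
                                 const float* saved_mean, const float* saved_inv_var,
                                 float* dx, float& dscale, float& dbias)
{
    float mean, inv_var;
    if(saved_mean != nullptr && saved_inv_var != nullptr)
    {
        // Step 1, short path: reuse the statistics saved by the forward pass.
        mean    = *saved_mean;
        inv_var = *saved_inv_var;
    }
    else
    {
        // Step 1: one-pass Welford for mean and variance,
        // then inv_var = 1/sqrt(epsilon + variance).
        float m = 0.f, m2 = 0.f;
        for(std::size_t i = 0; i < reduce_size; ++i)
        {
            const float delta = x[i] - m;
            m += delta / static_cast<float>(i + 1);
            m2 += delta * (x[i] - m);
        }
        mean    = m;
        inv_var = 1.f / std::sqrt(epsilon + m2 / reduce_size);
    }

    // Step 2: dbias = sum(dy), dscale = sum(dy * (x - mean) * inv_var).
    dbias  = 0.f;
    dscale = 0.f;
    for(std::size_t i = 0; i < reduce_size; ++i)
    {
        dbias += dy[i];
        dscale += dy[i] * (x[i] - mean) * inv_var;
    }

    // Step 3: dx = 1/reduce_size * inv_var * scale *
    //              (reduce_size * dy - dbias - dscale * (x - mean) * inv_var).
    for(std::size_t i = 0; i < reduce_size; ++i)
        dx[i] = (1.f / reduce_size) * inv_var * scale *
                (reduce_size * dy[i] - dbias - dscale * (x[i] - mean) * inv_var);
}
```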