Commit 213187f6 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Add/update the comments for batchnorm-backward kernels

parent a5d2bb15
...@@ -140,6 +140,11 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -140,6 +140,11 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Two of the steps of Multiblock BatchNorm Backward
// Step 1: Second half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// Step 2: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance), elementwise
// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k, const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k, const XYGridDesc_M_K& dx_grid_desc_m_k,
...@@ -208,7 +213,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -208,7 +213,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// Step 1: do final reduction for scale_diff and bias_diff and output // clang-format off
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
auto threadwise_dscale_load_m_k = auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType, ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
...@@ -340,8 +347,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -340,8 +347,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
bias_grid_desc_m, bias_grid_desc_m,
dbias_global_buf); dbias_global_buf);
// Step 2: calculate dx = 1/N * invVar * scale * (N * dy - biasDiff - scaleDiff * (x - mean) // clang-format off
// * invVar) and output // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
// clang-format on
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
......
...@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf ...@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
// clang-format on
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
......
...@@ -145,6 +145,11 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -145,6 +145,11 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Two of the steps of Multiblock BatchNorm Backward
// Step 1: Second half of Welford method to calculate mean and variance, as well as getting inv-variance = 1/sqrt(epsilon+variance)
// Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k, const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m, const MeanVarGridDesc_M& mean_var_grid_desc_m,
...@@ -196,7 +201,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -196,7 +201,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
dy_thread_buf; dy_thread_buf;
// buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction // buffer of values of dy * (x-mean) * inv-variance, used as input of Blockwise reduction
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
tmp1_thread_buf; tmp1_thread_buf;
...@@ -226,8 +231,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -226,8 +231,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// Step 1: load existing mean and inv-variance do final welford reduction on mean and // clang-format off
// variance // Step 1: load existing mean and inv-variance, or do final welford reduction on mean and variance as well as get inv-variance = 1/sqrt(epsilon+variance)
// clang-format on
if(haveSavedMeanInvVar) if(haveSavedMeanInvVar)
{ {
...@@ -451,7 +457,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -451,7 +457,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0); reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
}); });
// Step 2: do first-half reduction on dy and dy * (x-mean) * inv-variance // clang-format off
// Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles) for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{ {
......
...@@ -153,9 +153,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -153,9 +153,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
// Blockwise BatchNorm Backward // Blockwise BatchNorm Backward
// Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size // Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size
// Output: dx, dscale, dbias // Output: dx, dscale, dbias
// Step 1: calculate to get mean and invVariance using welford method (if savedMean and savedInvVar not available) // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
// Step 2: reduce on dy and dy * (x - mean) * invVariance to get dbias and dscale respectively // Step 2: reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// Step 3: calculate 1/reduce_size * invVariance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * invVariance)) to get dx elementwise-ly // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance), elementwise
// clang-format on // clang-format on
__device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k, __device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
...@@ -344,6 +344,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -344,6 +344,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize()); p_dbias, bias_grid_desc_m.GetElementSpaceSize());
// clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
// clang-format on
if(haveSavedMeanInvVar) if(haveSavedMeanInvVar)
{ {
const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -423,6 +427,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -423,6 +427,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
}; };
// clang-format off
// Step 2: reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dscale_thread_buf(I) = type_convert<AccDataType>(0); dscale_thread_buf(I) = type_convert<AccDataType>(0);
dbias_thread_buf(I) = type_convert<AccDataType>(0); dbias_thread_buf(I) = type_convert<AccDataType>(0);
...@@ -482,6 +490,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -482,6 +490,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
dbias_global_buf); dbias_global_buf);
}; };
// clang-format off
// Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance), elementwise
// clang-format on
threadwise_scale_load.Run(scale_grid_desc_m, threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_buf, scale_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment