Commit 7d114e80 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Add comments to explain the implementation of batchnorm-backward

parent 3ca7feeb
...@@ -149,6 +149,14 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -149,6 +149,14 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// Blockwise BatchNorm Backward
// Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size
// Output: dx, dscale, dbias
// Step 1: calculate to get mean and invVariance using welford method (if savedMean and savedInvVar not available)
// Step 2: reduce on dy and dy * (x - mean) * invVariance to get dbias and dscale respectively
// Step 3: calculate 1/reduce_size * invVariance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * invVariance)) to get dx elementwise-ly
// clang-format on
__device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k, __device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment