Commit 27f8c64b authored by rocking
Browse files

Add comment

parent e8ded1e7
...@@ -330,7 +330,12 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -330,7 +330,12 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_; gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_;
numMeanVarCountIteration_ = math::integer_divide_ceil(kGridSize_, KThreadClusterSize);
// We do not use vector load for mean, var and count
static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize;
numMeanVarCountIteration_ =
math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize);
x_grid_desc_m_k_ = x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_);
...@@ -347,12 +352,14 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -347,12 +352,14 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
kGridSize_); kGridSize_);
kernel2_mean_var_grid_desc_m_kblock_ = kernel2_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, true>, M_BlockTileSize, K_BlockTileSize>( MakeMeanVarDescriptor_M_K<Sequence<true, true>,
MRaw_, kGridSize_); M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
kernel2_count_grid_desc_m_kblock_ = kernel2_count_grid_desc_m_kblock_ =
MakeCountDescriptor_M_K<Sequence<true, true>, M_BlockTileSize, K_BlockTileSize>( MakeCountDescriptor_M_K<Sequence<true, true>,
MRaw_, kGridSize_); M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
} }
ComputeDataType epsilon_; ComputeDataType epsilon_;
......
...@@ -224,6 +224,7 @@ struct GridwiseNormalizationSplitK1st ...@@ -224,6 +224,7 @@ struct GridwiseNormalizationSplitK1st
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
// The value of count is the same for all I
if constexpr(I == MThreadSliceSize - 1) if constexpr(I == MThreadSliceSize - 1)
welford_count = count; welford_count = count;
}); });
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment