Commit 27f8c64b authored by rocking's avatar rocking
Browse files

Add comment

parent e8ded1e7
......@@ -330,7 +330,12 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_;
numMeanVarCountIteration_ = math::integer_divide_ceil(kGridSize_, KThreadClusterSize);
// We do not use vector load for mean, var and count
static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize;
numMeanVarCountIteration_ =
math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize);
x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_);
......@@ -347,12 +352,14 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
kGridSize_);
kernel2_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, true>, M_BlockTileSize, K_BlockTileSize>(
MRaw_, kGridSize_);
MakeMeanVarDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
kernel2_count_grid_desc_m_kblock_ =
MakeCountDescriptor_M_K<Sequence<true, true>, M_BlockTileSize, K_BlockTileSize>(
MRaw_, kGridSize_);
MakeCountDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
}
ComputeDataType epsilon_;
......
......@@ -224,6 +224,7 @@ struct GridwiseNormalizationSplitK1st
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
// The value of count is same for all I
if constexpr(I == MThreadSliceSize - 1)
welford_count = count;
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment