Commit db0a27ad authored by rocking
Browse files

Fix blockwise Welford bug in the first kernel

parent e89422a8
...@@ -857,12 +857,15 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -857,12 +857,15 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
using BlockwiseWelford = BlockwiseWelford<AccDataType, using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize, BlockSize,
PostShuffleThreadClusterSize_M_N, PostShuffleThreadClusterSize_M_N,
Sequence<0, 1>, Sequence<1, 0>,
false>; false>;
constexpr int num_shuffleM = constexpr int num_shuffleM =
MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl); MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
constexpr int num_shuffleN =
NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl);
using mean_var_vgpr_type = using mean_var_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>( decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize())); thread_welford_dst_desc_m.GetElementSpaceSize()));
...@@ -878,7 +881,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -878,7 +881,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
static_for<0, num_shuffleM, 1>{}([&](auto i) { static_for<0, num_shuffleM, 1>{}([&](auto i) {
// TODO - padding // TODO - padding
threadwise_welfords(i).max_count_ = PostShuffleThreadSliceSize_N; threadwise_welfords(i).max_count_ = PostShuffleThreadSliceSize_N * num_shuffleN;
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>( mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()); thread_welford_dst_desc_m.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment