Commit 69b925b9 authored by rocking's avatar rocking
Browse files

Merge the mean and var threadwise copy

parent f278b2a5
...@@ -1038,27 +1038,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1038,27 +1038,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr int shuffleMPerBlock = constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto mean_var_count_thread_copy_index = make_multi_index(
AccDataType, block_work_idx[I0], // mblock
EMeanVarDataType, shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
decltype(thread_welford_desc_I_m_I), block_work_idx[I1]); // nblock
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{
mean_var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
EMeanVarDataType, EMeanVarDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
...@@ -1070,26 +1055,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1070,26 +1055,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{ true>{mean_var_grid_desc_mblock_mperblock_nblock,
mean_var_grid_desc_mblock_mperblock_nblock, mean_var_count_thread_copy_index,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
mean_thread_buf, mean_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock, mean_var_grid_desc_mblock_mperblock_nblock,
mean_grid_buf); mean_grid_buf); // write mean
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, mean_var_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
var_thread_buf, var_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock, mean_var_grid_desc_mblock_mperblock_nblock,
var_grid_buf); var_grid_buf); // write variance
// Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
// to be writed.
if(i == 0 && block_work_idx[I0] == 0 && if(i == 0 && block_work_idx[I0] == 0 &&
post_shuffle_thread_data_idx_begin[I0] == 0) post_shuffle_thread_data_idx_begin[I0] == 0)
{ {
...@@ -1106,11 +1091,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1106,11 +1091,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{count_grid_desc_mblock_mperblock_nblock, false>{count_grid_desc_mblock_mperblock_nblock,
make_multi_index( mean_var_count_thread_copy_index,
block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
count_thread_copy_vgpr_to_global.Run( count_thread_copy_vgpr_to_global.Run(
...@@ -1118,7 +1099,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1118,7 +1099,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
count_thread_buf, count_thread_buf,
count_grid_desc_mblock_mperblock_nblock, count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf); welford_count_grid_buf); // write count
} }
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment