Commit bb7294ae authored by rocking
Browse files

Prevent redundant IO

parent 3df07c27
...@@ -1030,86 +1030,97 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1030,86 +1030,97 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j)); mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
}); });
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed( if(post_shuffle_thread_cluster_idx[I1] == 0)
make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1)); {
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
constexpr int shuffleMPerBlock = make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
constexpr int shuffleMPerBlock =
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
AccDataType,
MeanDataType, auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
decltype(thread_welford_desc_I_m_I), AccDataType,
decltype(mean_var_grid_desc_mblock_mperblock_nblock), MeanDataType,
tensor_operation::element_wise::PassThrough, decltype(thread_welford_desc_I_m_I),
Sequence<1, PostShuffleThreadSliceSize_M, 1>, decltype(mean_var_grid_desc_mblock_mperblock_nblock),
Sequence<0, 1, 2>, tensor_operation::element_wise::PassThrough,
1, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
1, Sequence<0, 1, 2>,
InMemoryDataOperationEnum::Set, 1,
1, 1,
false>{mean_var_grid_desc_mblock_mperblock_nblock, InMemoryDataOperationEnum::Set,
make_multi_index(block_work_idx[I0], // mblock 1,
shuffleMPerBlock * i + false>{
post_shuffle_thread_data_idx_begin[I0], // mperblock mean_var_grid_desc_mblock_mperblock_nblock,
block_work_idx[I1]), // nblock make_multi_index(block_work_idx[I0], // mblock
tensor_operation::element_wise::PassThrough{}}; shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< block_work_idx[I1]), // nblock
AccDataType, tensor_operation::element_wise::PassThrough{}};
VarDataType,
decltype(thread_welford_desc_I_m_I), auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
decltype(mean_var_grid_desc_mblock_mperblock_nblock), AccDataType,
tensor_operation::element_wise::PassThrough, VarDataType,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, decltype(thread_welford_desc_I_m_I),
Sequence<0, 1, 2>, decltype(mean_var_grid_desc_mblock_mperblock_nblock),
1, tensor_operation::element_wise::PassThrough,
1, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
InMemoryDataOperationEnum::Set, Sequence<0, 1, 2>,
1, 1,
false>{mean_var_grid_desc_mblock_mperblock_nblock, 1,
make_multi_index(block_work_idx[I0], // mblock InMemoryDataOperationEnum::Set,
shuffleMPerBlock * i + 1,
post_shuffle_thread_data_idx_begin[I0], // mperblock false>{
block_work_idx[I1]), // nblock mean_var_grid_desc_mblock_mperblock_nblock,
tensor_operation::element_wise::PassThrough{}}; make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< post_shuffle_thread_data_idx_begin[I0], // mperblock
int32_t, block_work_idx[I1]), // nblock
int32_t, tensor_operation::element_wise::PassThrough{}};
decltype(thread_welford_desc_I_m_I),
decltype(count_grid_desc_mblock_mperblock_nblock), mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
tensor_operation::element_wise::PassThrough, make_tuple(I0, I0, I0),
Sequence<1, PostShuffleThreadSliceSize_M, 1>, mean_thread_buf,
Sequence<0, 1, 2>, mean_var_grid_desc_mblock_mperblock_nblock,
1, mean_grid_buf);
1,
InMemoryDataOperationEnum::Set, var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
1, make_tuple(I0, I0, I0),
false>{count_grid_desc_mblock_mperblock_nblock, var_thread_buf,
make_multi_index(block_work_idx[I0], // mblock mean_var_grid_desc_mblock_mperblock_nblock,
shuffleMPerBlock * i + var_grid_buf);
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock if(i == 0 && block_work_idx[I0] == 0 &&
tensor_operation::element_wise::PassThrough{}}; post_shuffle_thread_data_idx_begin[I0] == 0)
{
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
make_tuple(I0, I0, I0), int32_t,
mean_thread_buf, int32_t,
mean_var_grid_desc_mblock_mperblock_nblock, decltype(thread_welford_desc_I_m_I),
mean_grid_buf); decltype(count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
make_tuple(I0, I0, I0), Sequence<0, 1, 2>,
var_thread_buf, 1,
mean_var_grid_desc_mblock_mperblock_nblock, 1,
var_grid_buf); InMemoryDataOperationEnum::Set,
1,
count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, false>{count_grid_desc_mblock_mperblock_nblock,
make_tuple(I0, I0, I0), make_multi_index(
count_thread_buf, block_work_idx[I0], // mblock
count_grid_desc_mblock_mperblock_nblock, shuffleMPerBlock * i +
welford_count_grid_buf); post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
}
}
}); });
} // shuffle C + Ds + welford + write out } // shuffle C + Ds + welford + write out
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment