Commit 50d7b4fc authored by Chao Liu

shuffle more than one M/NRepeat

parent 26ce5e12
@@ -695,12 +695,52 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 #else
         // shuffle and write out
         {
+#if 1
+            // TODO: make it tunable
+            constexpr index_t MRepeatPerShuffle_CCopy = 1;
+            constexpr index_t NRepeatPerShuffle_CCopy = 1;
+
+            // TODO: this is hardcoded, only works for BlockSize = 256. fix it!
+            constexpr index_t MRepeatThread_CCopy = 1;
+            constexpr index_t MThread_CCopy       = 32;
+            constexpr index_t NRepeatThread_CCopy = 1;
+            constexpr index_t NThread_CCopy       = 8;
+
+            // vector length for blockwise copy from LDS to global
+            constexpr index_t NScalarPerVector_CCopy = 8;
+#else
+            // TODO: make it tunable
+            constexpr index_t MRepeatPerShuffle_CCopy = 1;
+            constexpr index_t NRepeatPerShuffle_CCopy = 2;
+
+            // TODO: this is hardcoded, only works for BlockSize = 256. fix it!
+            constexpr index_t MRepeatThread_CCopy = 1;
+            constexpr index_t MThread_CCopy       = 16;
+            constexpr index_t NRepeatThread_CCopy = 2;
+            constexpr index_t NThread_CCopy       = 8;
+
+            // vector length for blockwise copy from LDS to global
+            constexpr index_t NScalarPerVector_CCopy = 8;
+#endif
+
+            static_assert(MRepeat % MRepeatPerShuffle_CCopy == 0 &&
+                              NRepeat % NRepeatPerShuffle_CCopy == 0,
+                          "wrong!");
+
             constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
             constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

             constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;
             constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;

+            constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
+            constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
+
+            constexpr index_t MRepeatPerThread_CCopy =
+                MRepeatPerShuffle_CCopy / MRepeatThread_CCopy;
+            constexpr index_t NRepeatPerThread_CCopy =
+                NRepeatPerShuffle_CCopy / NRepeatThread_CCopy;
+
             // TODO: hacky, fix it!
             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
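For reference: each shuffle iteration stages an MRepeatPerShuffle_CCopy x NRepeatPerShuffle_CCopy group of (MWave * MPerXdl) x (NWave * NPerXdl) tiles through LDS, and the new static_assert guards that MRepeat and NRepeat split evenly into such groups. Below is a minimal standalone sketch of the same arithmetic; the block, xdlops, and repeat sizes are assumed example values, not the kernel's actual template parameters:

// Standalone sketch of the tile arithmetic above; all sizes are assumed
// example values chosen so the numbers work out for BlockSize = 256.
#include <cstdio>

using index_t = int;

int main()
{
    // assumed tile configuration (illustrative only)
    constexpr index_t MPerBlock = 256, NPerBlock = 128;
    constexpr index_t MPerXdl = 32, NPerXdl = 32;
    constexpr index_t MRepeat = 4, NRepeat = 2;

    // shuffle granularity and thread layout, as in the "#if 1" branch
    constexpr index_t MRepeatPerShuffle = 1, NRepeatPerShuffle = 1;
    constexpr index_t MThread = 32, NThread = 8;

    static_assert(MRepeat % MRepeatPerShuffle == 0 && NRepeat % NRepeatPerShuffle == 0,
                  "repeat counts must be divisible by the shuffle granularity");

    // one shuffle iteration stages (MWave*MPerXdl) x (NWave*NPerXdl) elements in LDS
    constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); // 2
    constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); // 2
    constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;       // 64
    constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;       // 64

    // per-thread tile of the LDS -> global copy
    constexpr index_t MPerThread = MPerBlock_CCopy / MThread; // 2
    constexpr index_t NPerThread = NPerBlock_CCopy / NThread; // 8

    std::printf("LDS tile per shuffle: %d x %d, per-thread tile: %d x %d\n",
                MRepeatPerShuffle * MPerBlock_CCopy,
                NRepeatPerShuffle * NPerBlock_CCopy,
                MPerThread,
                NPerThread);
}

With these assumed values, 32 * 8 = 256 threads each own a 2 x 8 sub-tile of the 64 x 64 LDS tile, matching BlockSize = 256.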
@@ -719,20 +759,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
             constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

-            // TODO: this is hardcoded, only works for BlockSize = 256. fix it!
-            constexpr index_t MThread_CCopy = 32;
-            constexpr index_t NThread_CCopy = 8;
-
-            constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
-            constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
-
             constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
-                make_naive_tensor_descriptor_packed(make_tuple(
-                    I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{}));
-
-            static_assert(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
-                              .GetElementSpaceSize() == 64 * 64,
-                          "wrong!");
+                make_naive_tensor_descriptor_packed(make_tuple(I1,
+                                                               Number<MRepeatPerShuffle_CCopy>{},
+                                                               Number<MWave * MPerXdl>{},
+                                                               I1,
+                                                               Number<NRepeatPerShuffle_CCopy>{},
+                                                               Number<NWave * NPerXdl>{}));

             auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
                 static_cast<FloatC*>(p_shared),
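The LDS staging buffer now scales with the shuffle granularity, which is why the old static_assert pinning GetElementSpaceSize() to 64 * 64 is dropped. For a packed (fully contiguous) descriptor the element space is simply the product of the lengths; a small sketch, assuming MWave * MPerXdl = NWave * NPerXdl = 64 as the old assert implied:

// Sketch of the "packed element space" of the LDS staging descriptor above:
// for a packed descriptor it is just the product of the six lengths.
// The lengths below are assumed example values.
#include <array>
#include <cstdio>
#include <functional>
#include <numeric>

int main()
{
    // lengths of (MBlock, MRepeatPerShuffle, MWave*MPerXdl,
    //             NBlock, NRepeatPerShuffle, NWave*NPerXdl)
    std::array<long, 6> lengths{1, 1, 64, 1, 2, 64}; // NRepeatPerShuffle = 2 branch

    long elems =
        std::accumulate(lengths.begin(), lengths.end(), 1L, std::multiplies<>{});

    std::printf("LDS staging buffer: %ld elements\n", elems); // 8192, vs 4096 before
}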
@@ -741,12 +774,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                 c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                 make_tuple(make_freeze_transform(I0), // freeze mblock
-                           make_pass_through_transform(I1), // M0 (MRepeat) per shuffle = 1
+                           make_pass_through_transform(
+                               Number<MRepeatPerShuffle_CCopy>{}), // M0 (MRepeat) per shuffle
                            make_unmerge_transform(
                                make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                            make_freeze_transform(I0), // freeze nblock
-                           make_pass_through_transform(I1), // N0 (NRepeat) per shuffle = 1
+                           make_pass_through_transform(
+                               Number<NRepeatPerShuffle_CCopy>{}), // N0 (NRepeat) per shuffle
                            make_unmerge_transform(
                                make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
                 make_tuple(Sequence<0>{},
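The unmerge transforms above split one flat axis into mixed-radix coordinates: the MWave * MPerXdl axis becomes (M1, M2, M3, M4) and the NWave * NPerXdl axis becomes (N1, N2). A standalone sketch of that index decomposition, with assumed example radices:

// Sketch of the index math behind make_unmerge_transform: a flat index over
// MWave*MPerXdl is split into mixed-radix coordinates (M1, M2, M3, M4).
// The radices are assumed example values, not read from the kernel.
#include <array>
#include <cstdio>

int main()
{
    constexpr std::array<int, 4> lengths{2, 2, 4, 8}; // M1 = MWave, M2*M3*M4 = MPerXdl
    int flat = 77;                                    // some index in [0, 2*2*4*8)

    std::array<int, 4> coord{};
    int rem = flat;
    for(int d = 3; d >= 0; --d) // last dimension varies fastest
    {
        coord[d] = rem % lengths[d];
        rem /= lengths[d];
    }

    std::printf("flat %d -> (M1, M2, M3, M4) = (%d, %d, %d, %d)\n",
                flat, coord[0], coord[1], coord[2], coord[3]);
}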
@@ -793,39 +828,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                 make_multi_index(n_thread_data_on_block));

             // VGPR to LDS
-            auto c_thread_copy_vgpr_to_lds =
-                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
-                                                   FloatC,
-                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   ck::tensor_operation::element_wise::PassThrough,
-                                                   Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
-                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                                   7,
-                                                   1,
-                                                   InMemoryDataOperationEnum_t::Set,
-                                                   1,
-                                                   true>{
-                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                    make_multi_index(0,
-                                     0,
-                                     m_thread_data_on_block_idx[I1],
-                                     n_thread_data_on_block_idx[I1],
-                                     m_thread_data_on_block_idx[I2],
-                                     m_thread_data_on_block_idx[I3],
-                                     m_thread_data_on_block_idx[I4],
-                                     n_thread_data_on_block_idx[I2]),
-                    ck::tensor_operation::element_wise::PassThrough{}};
+            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
+                FloatAcc,
+                FloatC,
+                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                ck::tensor_operation::element_wise::PassThrough,
+                Sequence<MRepeatPerShuffle_CCopy, NRepeatPerShuffle_CCopy, I1, I1, M2, I1, M4, I1>,
+                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                7,
+                1,
+                InMemoryDataOperationEnum_t::Set,
+                1,
+                true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                      make_multi_index(0,
+                                       0,
+                                       m_thread_data_on_block_idx[I1],
+                                       n_thread_data_on_block_idx[I1],
+                                       m_thread_data_on_block_idx[I2],
+                                       m_thread_data_on_block_idx[I3],
+                                       m_thread_data_on_block_idx[I4],
+                                       n_thread_data_on_block_idx[I2]),
+                      ck::tensor_operation::element_wise::PassThrough{}};

             auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4<
                 BlockSize,                                       // index_t BlockSize,
                 ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
                 CGlobalMemoryDataOperation,                      // DstInMemOp,
-                Sequence<1, 1, MPerBlock_CCopy, 1, 1, NPerBlock_CCopy>,   // BlockSliceLengths,
-                Sequence<1, 1, MPerThread_CCopy, 1, 1, NPerThread_CCopy>, // ThreadSliceLengths,
-                Sequence<1, 1, MThread_CCopy, 1, 1, NThread_CCopy>,       // ThreadClusterLengths,
+                Sequence<1,
+                         MRepeatPerShuffle_CCopy,
+                         MPerBlock_CCopy,
+                         1,
+                         NRepeatPerShuffle_CCopy,
+                         NPerBlock_CCopy>, // BlockSliceLengths,
+                Sequence<1,
+                         MRepeatPerShuffle_CCopy,
+                         MPerThread_CCopy,
+                         1,
+                         NRepeatPerShuffle_CCopy,
+                         NPerThread_CCopy>, // ThreadSliceLengths,
+                Sequence<1,
+                         MRepeatPerThread_CCopy,
+                         MThread_CCopy,
+                         1,
+                         NRepeatPerThread_CCopy,
+                         NThread_CCopy>, // ThreadClusterLengths,
                 Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
                 FloatC,                     // typename SrcData,
                 FloatC,                     // typename DstData,
                 decltype(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
                 decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
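A configuration like this has to keep the three length sequences consistent: in each dimension the thread slice times the thread cluster should cover the block slice, and the product of the cluster lengths is the number of lanes the copy uses out of BlockSize. A generic sanity check, recomputed here for the numbers of the "#if 1" branch rather than extracted from the kernel:

// Generic consistency check for a blockwise-copy configuration like the one
// above. The concrete numbers correspond to the "#if 1" branch under the same
// assumed tile sizes as the earlier sketch; they are illustrative only.
#include <cstdio>

int main()
{
    constexpr int BlockSize = 256;

    constexpr int block_slice[6]    = {1, 1, 64, 1, 1, 64}; // BlockSliceLengths
    constexpr int thread_slice[6]   = {1, 1, 2, 1, 1, 8};   // ThreadSliceLengths
    constexpr int thread_cluster[6] = {1, 1, 32, 1, 1, 8};  // ThreadClusterLengths

    int lanes = 1;
    for(int d = 0; d < 6; ++d)
    {
        if(thread_slice[d] * thread_cluster[d] != block_slice[d])
            std::printf("dim %d: slice * cluster != block slice\n", d);
        lanes *= thread_cluster[d];
    }

    std::printf("lanes used: %d of BlockSize %d\n", lanes, BlockSize); // 256 of 256
}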
@@ -833,8 +882,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                 Sequence<0, 1, 2, 3, 4, 5>, // typename DstDimAccessOrder,
                 5,                          // index_t SrcVectorDim,
                 5,                          // index_t DstVectorDim,
-                NPerThread_CCopy,           // index_t SrcScalarPerVector,
-                NPerThread_CCopy,           // index_t DstScalarPerVector,
+                NScalarPerVector_CCopy,     // index_t SrcScalarPerVector,
+                NScalarPerVector_CCopy,     // index_t DstScalarPerVector,
                 1,                          // index_t SrcScalarStrideInVector,
                 1,                          // index_t DstScalarStrideInVector,
                 true,                       // bool ThreadTransferSrcResetCoordinateAfterRun,
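Giving the vector length its own knob (NScalarPerVector_CCopy) decouples the store width from the thread tile, but a legal width must still evenly tile the contiguous run each thread owns along N. A hypothetical standalone check of that constraint:

// Sketch of the vector-width constraint; both values are assumed examples.
#include <cstdio>

int main()
{
    constexpr int NPerThread       = 8; // contiguous scalars per thread along N
    constexpr int NScalarPerVector = 8; // requested store vector width

    static_assert(NPerThread % NScalarPerVector == 0,
                  "vector width must evenly tile the per-thread run");

    std::printf("each thread issues %d vector store(s) of %d scalars\n",
                NPerThread / NScalarPerVector, NScalarPerVector);
}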
@@ -845,24 +894,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                 make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
                 ck::tensor_operation::element_wise::PassThrough{}};

-            constexpr auto mrepeat_forward_step  = make_multi_index(0, 1, 0, 0, 0, 0);
-            constexpr auto nrepeat_forward_step  = make_multi_index(0, 0, 0, 0, 1, 0);
-            constexpr auto nrepeat_backward_step = make_multi_index(0, 0, 0, 0, -1, 0);
-
-            // make sure all ds_read from GEMM is completed
-            block_sync_lds();
+            constexpr auto mrepeat_forward_step =
+                make_multi_index(0, MRepeatPerShuffle_CCopy, 0, 0, 0, 0);
+            constexpr auto nrepeat_forward_step =
+                make_multi_index(0, 0, 0, 0, NRepeatPerShuffle_CCopy, 0);
+            constexpr auto nrepeat_backward_step =
+                make_multi_index(0, 0, 0, 0, -NRepeatPerShuffle_CCopy, 0);

-            static_for<0, MRepeat, 1>{}([&](auto mrepeat_iter) {
+            static_for<0, MRepeat, MRepeatPerShuffle_CCopy>{}([&](auto mrepeat_iter) {
                 constexpr auto mrepeat = mrepeat_iter;

-                static_for<0, NRepeat, 1>{}([&](auto nrepeat_iter) {
-                    constexpr bool nrepeat_forward_sweep = (mrepeat % 2 == 0);
+                static_for<0, NRepeat, NRepeatPerShuffle_CCopy>{}([&](auto nrepeat_iter) {
+                    constexpr bool nrepeat_forward_sweep =
+                        (mrepeat % (2 * MRepeatPerShuffle_CCopy) == 0);

                     constexpr index_t nrepeat_value =
-                        nrepeat_forward_sweep ? nrepeat_iter : (NRepeat - nrepeat_iter - 1);
+                        nrepeat_forward_sweep ? nrepeat_iter
+                                              : (NRepeat - nrepeat_iter - NRepeatPerShuffle_CCopy);

                     constexpr auto nrepeat = Number<nrepeat_value>{};

+                    // make sure it's safe to do ds_write
+                    block_sync_lds();
+
                     // VGPR to LDS
                     c_thread_copy_vgpr_to_lds.Run(
                         c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
@@ -871,7 +925,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                         c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                         c_block_buf);

-                    // make sure ds_write from c_thread_copy_vgpr_to_lds is completed
+                    // make sure it's safe to do ds_read
                     block_sync_lds();

                     // LDS to global
@@ -881,11 +935,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                         c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                         c_grid_buf);

-                    // make sure ds_read from c_block_copy_lds_to_global is completed
-                    block_sync_lds();
-
                     // move on nrepeat dimension
-                    if constexpr(nrepeat_forward_sweep && (nrepeat < NRepeat - 1))
+                    if constexpr(nrepeat_forward_sweep &&
+                                 (nrepeat < NRepeat - NRepeatPerShuffle_CCopy))
                     {
                         c_block_copy_lds_to_global.MoveDstSliceWindow(
                             c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
@@ -900,7 +952,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                 });

                 // move on mrepeat dimension
-                if constexpr(mrepeat < MRepeat - 1)
+                if constexpr(mrepeat < MRepeat - MRepeatPerShuffle_CCopy)
                 {
                     c_block_copy_lds_to_global.MoveDstSliceWindow(
                         c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
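The rewritten loops advance by whole shuffle groups while keeping the serpentine traversal: nrepeat is walked forward on even mrepeat shuffles and backward on odd ones, so the LDS-to-global destination window only ever moves by a single step. A standalone sketch that prints the visit order; MRepeat, NRepeat, and the shuffle granularities are assumed example values:

// Standalone sketch of the sweep order the rewritten loop produces.
#include <cstdio>

int main()
{
    constexpr int MRepeat = 4, NRepeat = 4;
    constexpr int MRepeatPerShuffle = 1, NRepeatPerShuffle = 2;

    for(int mrepeat = 0; mrepeat < MRepeat; mrepeat += MRepeatPerShuffle)
    {
        // forward sweep on even mrepeat shuffles, backward on odd ones
        const bool forward = (mrepeat % (2 * MRepeatPerShuffle) == 0);

        for(int iter = 0; iter < NRepeat; iter += NRepeatPerShuffle)
        {
            const int nrepeat = forward ? iter : (NRepeat - iter - NRepeatPerShuffle);
            std::printf("shuffle tile (mrepeat=%d, nrepeat=%d)\n", mrepeat, nrepeat);
        }
    }
}

For these values the printed order is (0,0) (0,2) (1,2) (1,0) (2,0) (2,2) (3,2) (3,0), so consecutive shuffle tiles are always adjacent along one dimension.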