Commit 50d7b4fc authored by Chao Liu

shuffle more than one M/NRepeat

parent 26ce5e12
@@ -695,12 +695,52 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
#else
// shuffle and write out
{
#if 1
// TODO: make it tunable
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 1;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 32;
constexpr index_t NRepeatThread_CCopy = 1;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#else
// TODO: make it tunable
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 2;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 16;
constexpr index_t NRepeatThread_CCopy = 2;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#endif
static_assert(MRepeat % MRepeatPerShuffle_CCopy == 0 &&
NRepeat % NRepeatPerShuffle_CCopy == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;
constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;
constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
constexpr index_t MRepeatPerThread_CCopy =
MRepeatPerShuffle_CCopy / MRepeatThread_CCopy;
constexpr index_t NRepeatPerThread_CCopy =
NRepeatPerShuffle_CCopy / NRepeatThread_CCopy;
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -719,20 +759,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MThread_CCopy = 32;
constexpr index_t NThread_CCopy = 8;
constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{}));
static_assert(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
.GetElementSpaceSize() == 64 * 64,
"wrong!");
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeatPerShuffle_CCopy>{},
Number<MWave * MPerXdl>{},
I1,
Number<NRepeatPerShuffle_CCopy>{},
Number<NWave * NPerXdl>{}));
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatC*>(p_shared),
@@ -741,12 +774,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_tuple(make_freeze_transform(I0), // freeze mblock
make_pass_through_transform(I1), // M0 (MRepeat) per shuffle = 1
make_tuple(make_freeze_transform(I0), // freeze mblock
make_pass_through_transform(
Number<MRepeatPerShuffle_CCopy>{}), // M0 (MRepeat) per shuffle
make_unmerge_transform(
make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform(I0), // freeze nblock
make_pass_through_transform(I1), // N0 (NRepeat) per shuffle = 1
make_pass_through_transform(
Number<NRepeatPerShuffle_CCopy>{}), // N0 (NRepeat) per shuffle
make_unmerge_transform(
make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
make_tuple(Sequence<0>{},
@@ -793,39 +828,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
make_multi_index(n_thread_data_on_block));
// VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<MRepeatPerShuffle_CCopy, NRepeatPerShuffle_CCopy, I1, I1, M2, I1, M4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4<
BlockSize, // index_t BlockSize,
ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, 1, MPerBlock_CCopy, 1, 1, NPerBlock_CCopy>, // BlockSliceLengths,
Sequence<1, 1, MPerThread_CCopy, 1, 1, NPerThread_CCopy>, // ThreadSliceLengths,
Sequence<1, 1, MThread_CCopy, 1, 1, NThread_CCopy>, // ThreadClusterLengths,
BlockSize, // index_t BlockSize,
ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
MRepeatPerShuffle_CCopy,
MPerBlock_CCopy,
1,
NRepeatPerShuffle_CCopy,
NPerBlock_CCopy>, // BlockSliceLengths,
Sequence<1,
MRepeatPerShuffle_CCopy,
MPerThread_CCopy,
1,
NRepeatPerShuffle_CCopy,
NPerThread_CCopy>, // ThreadSliceLengths,
Sequence<1,
MRepeatPerThread_CCopy,
MThread_CCopy,
1,
NRepeatPerThread_CCopy,
NThread_CCopy>, // ThreadClusterLengths,
Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
@@ -833,8 +882,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
Sequence<0, 1, 2, 3, 4, 5>, // typename DstDimAccessOrder,
5, // index_t SrcVectorDim,
5, // index_t DstVectorDim,
NPerThread_CCopy, // index_t SrcScalarPerVector,
NPerThread_CCopy, // index_t DstScalarPerVector,
NScalarPerVector_CCopy, // index_t SrcScalarPerVector,
NScalarPerVector_CCopy, // index_t DstScalarPerVector,
1, // index_t SrcScalarStrideInVector,
1, // index_t DstScalarStrideInVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
@@ -845,24 +894,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
ck::tensor_operation::element_wise::PassThrough{}};
constexpr auto mrepeat_forward_step = make_multi_index(0, 1, 0, 0, 0, 0);
constexpr auto nrepeat_forward_step = make_multi_index(0, 0, 0, 0, 1, 0);
constexpr auto nrepeat_backward_step = make_multi_index(0, 0, 0, 0, -1, 0);
// make sure all ds_read from GEMM is completed
block_sync_lds();
constexpr auto mrepeat_forward_step =
make_multi_index(0, MRepeatPerShuffle_CCopy, 0, 0, 0, 0);
constexpr auto nrepeat_forward_step =
make_multi_index(0, 0, 0, 0, NRepeatPerShuffle_CCopy, 0);
constexpr auto nrepeat_backward_step =
make_multi_index(0, 0, 0, 0, -NRepeatPerShuffle_CCopy, 0);
static_for<0, MRepeat, 1>{}([&](auto mrepeat_iter) {
static_for<0, MRepeat, MRepeatPerShuffle_CCopy>{}([&](auto mrepeat_iter) {
constexpr auto mrepeat = mrepeat_iter;
static_for<0, NRepeat, 1>{}([&](auto nrepeat_iter) {
constexpr bool nrepeat_forward_sweep = (mrepeat % 2 == 0);
static_for<0, NRepeat, NRepeatPerShuffle_CCopy>{}([&](auto nrepeat_iter) {
constexpr bool nrepeat_forward_sweep =
(mrepeat % (2 * MRepeatPerShuffle_CCopy) == 0);
constexpr index_t nrepeat_value =
nrepeat_forward_sweep ? nrepeat_iter : (NRepeat - nrepeat_iter - 1);
nrepeat_forward_sweep ? nrepeat_iter
: (NRepeat - nrepeat_iter - NRepeatPerShuffle_CCopy);
constexpr auto nrepeat = Number<nrepeat_value>{};
// make sure it's safe to do ds_write
block_sync_lds();
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
@@ -871,7 +925,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);
// make sure ds_write from c_thread_copy_vgpr_to_lds is completed
// make sure it's safe to do ds_read
block_sync_lds();
// LDS to global
@@ -881,11 +935,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
c_grid_buf);
// make sure ds_read from c_block_copy_lds_to_global is completed
block_sync_lds();
// move on nrepeat dimension
if constexpr(nrepeat_forward_sweep && (nrepeat < NRepeat - 1))
if constexpr(nrepeat_forward_sweep &&
(nrepeat < NRepeat - NRepeatPerShuffle_CCopy))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
@@ -900,7 +952,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
});
// move on mrepeat dimension
if constexpr(mrepeat < MRepeat - 1)
if constexpr(mrepeat < MRepeat - MRepeatPerShuffle_CCopy)
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
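
For reference, a minimal host-side sketch of the zig-zag sweep implemented by the static_for loops above (not part of the commit; the repeat counts are assumed example values, and plain runtime loops stand in for the compile-time static_for). Even mrepeat steps walk nrepeat forward and odd steps walk it backward, so the destination slice window only moves by one shuffle step between LDS round-trips:

#include <cstdio>

int main()
{
    // assumed example values, not taken from the kernel configuration
    constexpr int MRepeat = 4, NRepeat = 4;
    constexpr int MRepeatPerShuffle = 1, NRepeatPerShuffle = 2;

    for(int mrepeat = 0; mrepeat < MRepeat; mrepeat += MRepeatPerShuffle)
    {
        // even m-steps sweep nrepeat forward, odd m-steps sweep it backward
        const bool forward = (mrepeat % (2 * MRepeatPerShuffle) == 0);

        for(int n_iter = 0; n_iter < NRepeat; n_iter += NRepeatPerShuffle)
        {
            const int nrepeat = forward ? n_iter : (NRepeat - n_iter - NRepeatPerShuffle);
            std::printf("shuffle tile: mrepeat = %d, nrepeat = %d\n", mrepeat, nrepeat);
        }
    }
    return 0;
}

With these assumed values the tiles are visited in the order (0,0), (0,2), (1,2), (1,0), (2,0), (2,2), (3,2), (3,0), i.e. a serpentine path over the C block tile.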