Commit ea41fc2f authored by aska-0096

Port new layout to v1, v2 pipeline

parent b70bcd86
@@ -66,7 +66,9 @@ constexpr auto BlockGemmPipeline_Selector()
             NPerXDL,
             MRepeat,
             NRepeat,
-            KPack>{};
+            KPack,
+            TransposeA,
+            TransposeB>{};
     }
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
     {
@@ -89,7 +91,9 @@ constexpr auto BlockGemmPipeline_Selector()
             NPerXDL,
             MRepeat,
             NRepeat,
-            KPack>{};
+            KPack,
+            TransposeA,
+            TransposeB>{};
     }
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
@@ -137,7 +141,9 @@ constexpr auto BlockGemmPipeline_Selector()
             NPerXDL,
             MRepeat,
             NRepeat,
-            KPack>{};
+            KPack,
+            TransposeA,
+            TransposeB>{};
     }
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
     {
@@ -160,7 +166,9 @@ constexpr auto BlockGemmPipeline_Selector()
             NPerXDL,
             MRepeat,
             NRepeat,
-            KPack>{};
+            KPack,
+            TransposeA,
+            TransposeB>{};
     }
     else
     {
...
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPacks>
+          index_t KPacks,
+          bool TransposeA,
+          bool TransposeB>
 struct BlockwiseGemmXdlops_pipeline_v1
 {
 };
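The mechanism above, in isolation: the selector's `if constexpr` chain forwards two new `bool` non-type template parameters into each pipeline's partial specialization, while the primary template stays an empty shell. A minimal, self-contained sketch of that pattern (illustrative names, not the CK definitions):

```cpp
#include <cstdio>

enum class Scheduler { Intrawave, Interwave };

// Primary template: an empty shell, like BlockwiseGemmXdlops_pipeline_v1.
template <Scheduler S, int KPack, bool TransposeA, bool TransposeB>
struct Pipeline
{
};

// Partial specialization pins the scheduler; the new bools flow through.
template <int KPack, bool TransposeA, bool TransposeB>
struct Pipeline<Scheduler::Intrawave, KPack, TransposeA, TransposeB>
{
    static void Describe()
    {
        std::printf("Intrawave: KPack=%d TransposeA=%d TransposeB=%d\n",
                    KPack, int(TransposeA), int(TransposeB));
    }
};

// Selector in the style of BlockGemmPipeline_Selector.
template <Scheduler S, int KPack, bool TransposeA, bool TransposeB>
constexpr auto PipelineSelector()
{
    if constexpr(S == Scheduler::Intrawave)
    {
        return Pipeline<Scheduler::Intrawave, KPack, TransposeA, TransposeB>{};
    }
}

int main()
{
    auto pipeline = PipelineSelector<Scheduler::Intrawave, 8, false, true>();
    decltype(pipeline)::Describe();
}
```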
@@ -55,7 +57,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                         NPerXDL,
                                         MRepeat,
                                         NRepeat,
-                                        KPack>
+                                        KPack,
+                                        TransposeA,
+                                        TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
     using Base::I0;
     using Base::KRepeat;
     using Base::xdlops_gemm;
@@ -218,23 +228,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

             block_sync_lds();
-            static_for<0, KRepeat, 1>{}([&](auto k) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k, I0),
-                                       a_thread_buf);
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k, I0),
-                                           b_thread_buf);
-                    });
-                });
+            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   a_thread_buf);
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   b_thread_buf);
             });

             static_for<0, KRepeat, 1>{}([&](auto k0) {
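The structural change in this hunk, stripped of CK machinery: the explicit `static_for` loops over `MRepeat` and `NRepeat` disappear because the repeat extent moves into the copy's slice length (visible later in this diff as `Sequence<MRepeat, 1, 1, ...>` in the new transfer types), so each `k0` step issues one A copy and one B copy instead of many small ones. A hedged sketch of the call-count difference, using a recursion-based stand-in for `ck::static_for`:

```cpp
#include <type_traits>

// Compile-time loop in the spirit of ck::static_for (C++17, recursion-based).
template <int N, class F>
constexpr void static_for(F&& f)
{
    if constexpr(N > 0)
    {
        static_for<N - 1>(f);
        f(std::integral_constant<int, N - 1>{});
    }
}

constexpr int KRepeat = 2, MRepeat = 4; // illustrative extents

// Old shape: one Run call per (k, m0) pair -> KRepeat * MRepeat calls.
constexpr int old_run_calls()
{
    int calls = 0;
    static_for<KRepeat>([&](auto) { static_for<MRepeat>([&](auto) { ++calls; }); });
    return calls;
}

// New shape: the m extent lives inside the copy's slice length,
// so only one Run call per k0 is issued.
constexpr int new_run_calls()
{
    int calls = 0;
    static_for<KRepeat>([&](auto) { ++calls; });
    return calls;
}

static_assert(old_run_calls() == 8 && new_run_calls() == 2);
```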
@@ -279,23 +286,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
             if constexpr(TailNum == TailNumber::Full)
             {
                 block_sync_lds();
-                static_for<0, KRepeat, 1>{}([&](auto k) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                           a_block_buf,
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k, I0),
-                                           a_thread_buf);
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                               b_block_buf,
-                                               b_thread_desc_,
-                                               make_tuple(n0, I0, k, I0),
-                                               b_thread_buf);
-                        });
-                    });
+                static_for<0, KRepeat, 1>{}([&](auto k0) {
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k0, I0),
+                                       a_thread_buf);
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k0, I0),
+                                       b_thread_buf);
                 });

                 static_for<0, KRepeat, 1>{}([&](auto k0) {
@@ -354,7 +358,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
@@ -376,7 +382,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -395,7 +403,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                         NPerXDL,
                                         MRepeat,
                                         NRepeat,
-                                        KPack>
+                                        KPack,
+                                        TransposeA,
+                                        TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -416,7 +426,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
     using Base::A_K1;
     using Base::B_K1;
     using Base::I0;
@@ -520,22 +532,18 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                 block_sync_lds();
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                           a_block_buf,
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k0, I0),
-                                           a_thread_buf);
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                               b_block_buf,
-                                               b_thread_desc_,
-                                               make_tuple(n0, I0, k0, I0),
-                                               b_thread_buf);
-                        });
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k0, I0),
+                                       a_thread_buf);
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k0, I0),
+                                       b_thread_buf);
                     __builtin_amdgcn_sched_barrier(0);
                     // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
                     // but except the first, as we can shorten non-MAC cluster a bit and there's no
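On the `__builtin_amdgcn_sched_barrier(0)` calls these hunks keep: the intrinsic emits no machine instruction; a mask of `0` forbids the compiler's instruction scheduler from moving anything across that point, which is what keeps the copy phase and the MAC cluster separated. A minimal HIP illustration (the kernel and its pointers are hypothetical):

```cpp
#include <hip/hip_runtime.h>

__global__ void phased_kernel(const float* in, float* out)
{
    // Phase 1: issue the loads.
    float a = in[threadIdx.x];
    float b = in[threadIdx.x + blockDim.x];

    // Compiler-only fence: with mask 0, no instruction may be scheduled
    // across this point, so the loads above stay grouped before the math.
    __builtin_amdgcn_sched_barrier(0);

    // Phase 2: compute and store.
    out[threadIdx.x] = a * b;
}
```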
@@ -614,22 +622,18 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
             {
                 block_sync_lds();
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                           a_block_buf,
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k0, I0),
-                                           a_thread_buf);
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                               b_block_buf,
-                                               b_thread_desc_,
-                                               make_tuple(n0, I0, k0, I0),
-                                               b_thread_buf);
-                        });
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k0, I0),
+                                       a_thread_buf);
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k0, I0),
+                                       b_thread_buf);
                     __builtin_amdgcn_sched_barrier(0);
                     if constexpr(k0.value != 0 || KRepeat == 1)
@@ -703,28 +707,80 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                                       Number<NRepeat * KPerInnerLoop>{},
                                                       I1));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
-                                                         ComputeDataType,
-                                                         decltype(a_block_desc_m0_m1_m2_k),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, 1, 1, KPerInnerLoop>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         A_K1,
-                                                         A_K1>;
-
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
-                                                         ComputeDataType,
-                                                         decltype(b_block_desc_n0_n1_n2_k),
-                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, 1, 1, KPerInnerLoop>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         B_K1,
-                                                         B_K1>;
-
-    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
+    template <bool Transpose>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      A_K1,
+                                                      A_K1>;
+    };
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      MRepeat,
+                                                      A_K1>;
+    };
+
+    template <bool Transpose>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      B_K1,
+                                                      B_K1>;
+    };
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPerInnerLoop>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      NRepeat,
+                                                      B_K1>;
+    };
+
+    typename AThreadCopySelector<TransposeA>::type a_thread_copy_{
+        Base::CalculateAThreadOriginDataIndex()};
+    typename BThreadCopySelector<TransposeB>::type b_thread_copy_{
+        Base::CalculateBThreadOriginDataIndex()};

     using Base::c_thread_desc_;
 };
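Two things are going on in the selector hunk. First, the transposed specializations differ from the non-transposed ones in the dimension access order (`Sequence<3, 1, 2, 0>` vs `Sequence<0, 1, 2, 3>`) and in the vector dimension and width (`0`/`MRepeat` or `0`/`NRepeat` instead of `3`/`A_K1` or `3`/`B_K1`), i.e. the transposed copy vectorizes along the repeat axis rather than along K. Second, the wrapper struct exists because `if constexpr` cannot choose a member's *type*; a bool-keyed trait with two specializations can. A reduced sketch of that idiom (stand-in types, not the real transfer classes):

```cpp
// Stand-ins for the two transfer configurations: the non-transposed copy
// vectorizes along K (dim 3); the transposed one along the repeat axis (dim 0).
template <int VectorDim, int ScalarPerVector>
struct ThreadCopy
{
    static constexpr int vector_dim        = VectorDim;
    static constexpr int scalar_per_vector = ScalarPerVector;
};

constexpr int A_K1 = 8, MRepeat = 4; // illustrative widths

// Class-scope type selection via a bool-keyed trait.
template <bool Transpose>
struct AThreadCopySelector;

template <>
struct AThreadCopySelector<false>
{
    using type = ThreadCopy</*VectorDim=*/3, /*ScalarPerVector=*/A_K1>;
};

template <>
struct AThreadCopySelector<true>
{
    using type = ThreadCopy</*VectorDim=*/0, /*ScalarPerVector=*/MRepeat>;
};

template <bool TransposeA>
struct PipelineFragment
{
    // The member's type flips with the template flag; no runtime branch exists.
    typename AThreadCopySelector<TransposeA>::type a_thread_copy_{};
};

static_assert(PipelineFragment<true>{}.a_thread_copy_.vector_dim == 0);
static_assert(PipelineFragment<false>{}.a_thread_copy_.scalar_per_vector == A_K1);
```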
...
@@ -146,7 +146,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     static constexpr auto BK1Number = Number<BK1Value>{};

     static constexpr index_t KPack =
-        math::max(math::gcd(AK1Number, BK1Number),
+        math::max(math::lcm(AK1Number, BK1Number),
                   MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
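Why `lcm` rather than `gcd` here: with mixed A/B vector widths, `gcd` can shrink `KPack` below one of the load widths, while `lcm` keeps it a multiple of both. A worked example with hypothetical widths (AK1 = 8, BK1 = 4, k_per_blk = 4):

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>

int main()
{
    // Hypothetical tile parameters, purely for illustration.
    constexpr int AK1 = 8, BK1 = 4, k_per_blk = 4;

    constexpr int kpack_gcd = std::max(std::gcd(AK1, BK1), k_per_blk); // max(4, 4) = 4
    constexpr int kpack_lcm = std::max(std::lcm(AK1, BK1), k_per_blk); // max(8, 4) = 8

    // With lcm, one KPack chunk always covers a whole AK1-wide A load and
    // a whole BK1-wide B load; with gcd it could split the wider one.
    std::printf("gcd-based KPack = %d, lcm-based KPack = %d\n", kpack_gcd, kpack_lcm);
}
```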
@@ -1016,16 +1016,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

-        // LDS allocation for C shuffle in LDS
-        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-
-        constexpr auto c_block_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-
-        return math::max((a_block_space_size_aligned * sizeof(ADataType) +
-                          b_block_space_size_aligned * sizeof(BDataType)),
-                         c_block_size * sizeof(CShuffleDataType));
+        return a_block_space_size_aligned * sizeof(ADataType) +
+               b_block_space_size_aligned * sizeof(BDataType);
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
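With the C-shuffle staging gone (the new epilogue below writes C straight from registers), shared memory only has to hold the aligned A and B tiles, so the `max(..., c_block_size * sizeof(CShuffleDataType))` term drops out. A back-of-the-envelope check with hypothetical tile sizes:

```cpp
#include <cstdio>

// Round x up to a multiple of align, as math::integer_least_multiple does.
constexpr int integer_least_multiple(int x, int align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    // Hypothetical fp16 tiles: 256x64 for A, 128x64 for B, 64-element alignment.
    constexpr int a_elems   = integer_least_multiple(256 * 64, 64); // 16384
    constexpr int b_elems   = integer_least_multiple(128 * 64, 64); // 8192
    constexpr int lds_bytes = a_elems * 2 + b_elems * 2;            // 49152 = 48 KiB

    // Previously: max(A + B, c_shuffle_bytes). Now C never touches LDS.
    std::printf("LDS bytes = %d\n", lds_bytes);
}
```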
@@ -1713,202 +1705,90 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                       c_thread_buf,
                                                       num_k_block_main_loop);

-        // shuffle C and write out
-        {
-            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
-                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
-                          "wrong!");
-
-            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
-            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
-
-            // TODO: hacky, fix it!
-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            // TODO: hacky, fix it!
-            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
-                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
-            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
-            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
-            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
-            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
-            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
-            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
-            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
-
-            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-
-            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                static_cast<CShuffleDataType*>(p_shared_0),
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                make_tuple(
-                    make_freeze_transform(I0),
-                    make_unmerge_transform(make_tuple(
-                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
-                        M1,                                      // M1 = MWave
-                        M2,                                      // M2 * M3 * M4 = MPerXdl
-                        M3,
-                        M4)),
-                    make_freeze_transform(I0),
-                    make_unmerge_transform(make_tuple(
-                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
-                        N1,                                      // N1 = NWave
-                        N2))),                                   // N2 = NPerXdl
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(
-                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
-
-            // calculate origin of thread output tensor on global memory
-            // blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block =
-                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
-
-            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
-            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
-
-            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
-                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto m_thread_data_on_block_idx =
-                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
-                    make_multi_index(m_thread_data_on_block));
-
-            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
-                    make_tuple(Sequence<0, 1, 2>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto n_thread_data_on_block_idx =
-                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
-                    make_multi_index(n_thread_data_on_block));
-
-            // shuffle: threadwise copy C from VGPR to LDS
-            auto c_thread_copy_vgpr_to_lds =
-                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                                   CShuffleDataType,
-                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   ck::tensor_operation::element_wise::PassThrough,
-                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
-                                                            CShuffleNXdlPerWavePerShuffle,
-                                                            I1,
-                                                            I1,
-                                                            M2,
-                                                            I1,
-                                                            M4,
-                                                            I1>,
-                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                                   7,
-                                                   1,
-                                                   InMemoryDataOperationEnum::Set,
-                                                   1,
-                                                   true>{
-                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                    make_multi_index(0,
-                                     0,
-                                     m_thread_data_on_block_idx[I1],
-                                     n_thread_data_on_block_idx[I1],
-                                     m_thread_data_on_block_idx[I2],
-                                     m_thread_data_on_block_idx[I3],
-                                     m_thread_data_on_block_idx[I4],
-                                     n_thread_data_on_block_idx[I2]),
-                    ck::tensor_operation::element_wise::PassThrough{}};
-
-            // shuffle: blockwise copy C from LDS to global
-            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
-                ThisThreadBlock,            // ThreadGroup
-                CElementwiseOperation,      // ElementwiseOperation,
-                CGlobalMemoryDataOperation, // DstInMemOp,
-                Sequence<1,
-                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                         1,
-                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-                CShuffleDataType,     // typename SrcData,
-                CDataType,            // typename DstData,
-                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
-                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
-                3,                                              // index_t VectorDim,
-                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
-                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
-                false> // bool ThreadTransferDstResetCoordinateAfterRun>
-                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                 make_multi_index(0, 0, 0, 0),
-                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                 make_multi_index(block_m_id, 0, block_n_id, 0),
-                 c_element_op};
-
-            // space filling curve for threadwise C in VGPR
-            constexpr auto sfc_c_vgpr =
-                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
-                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                  Sequence<CShuffleMXdlPerWavePerShuffle,
-                                           CShuffleNXdlPerWavePerShuffle,
-                                           1,
-                                           1,
-                                           M2,
-                                           1,
-                                           M4,
-                                           1>>{};
-
-            // space filling curve for shuffled blockwise C in global mem
-            constexpr auto sfc_c_global =
-                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
-                                  Sequence<0, 2, 1, 3>,
-                                  Sequence<1,
-                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                                           1,
-                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
-
-            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
-            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
-
-            static_for<0, num_access, 1>{}([&](auto access_id) {
-                // make sure it's safe to write to LDS
-                block_sync_lds();
-
-                // each thread write its data from VGPR to LDS
-                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                              c_thread_buf,
-                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              c_shuffle_block_buf);
-
-                // make sure it's safe to read from LDS
-                block_sync_lds();
-
-                // each block copy its data from LDS to global
-                c_shuffle_block_copy_lds_to_global.Run(
-                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                    c_shuffle_block_buf,
-                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                    c_grid_buf);
-
-                if constexpr(access_id < num_access - 1)
-                {
-                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
-
-                    // move on C
-                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
-                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
-                }
-            });
-        }
+        // Epilogue
+        constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
+            blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
+
+        constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
+            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
+
+        constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
+        constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
+        constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
+        constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
+        constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
+        constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
+        constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
+        constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
+
+        const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(make_pass_through_transform(problem.MBlock),
+                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
+                       make_pass_through_transform(problem.NBlock),
+                       make_unmerge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));
+
+        const auto c_thread_mtx_on_block =
+            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);
+
+        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
+        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
+
+        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
+                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                make_tuple(Sequence<0>{}));
+
+        const auto m_thread_data_on_block_idx =
+            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+                make_multi_index(m_thread_data_on_block));
+
+        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto n_thread_data_on_block_idx =
+            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+                make_multi_index(n_thread_data_on_block));
+
+        // Typecast -> Permute -> Coalesced vector store
+        auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
+            AccDataType,
+            CDataType,
+            decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
+            decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
+            CElementwiseOperation,
+            Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
+            9,
+            8,
+            M4,
+            N2,
+            InMemoryDataOperationEnum::Set,
+            false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                   make_multi_index(block_m_id,
+                                    block_n_id,
+                                    m_thread_data_on_block_idx[I0],
+                                    m_thread_data_on_block_idx[I1],
+                                    n_thread_data_on_block_idx[I0],
+                                    m_thread_data_on_block_idx[I2],
+                                    m_thread_data_on_block_idx[I3],
+                                    n_thread_data_on_block_idx[I1],
+                                    n_thread_data_on_block_idx[I2],
+                                    m_thread_data_on_block_idx[I4]),
+                   c_element_op};
+
+        c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                         c_thread_buf,
+                                         c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                                         c_grid_buf);
     }
     template <bool HasMainKBlockLoop,
...
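The shape of the new epilogue, reduced to scalars: the accumulator is typecast to the output type, the last two thread-tile dimensions are swapped on the way out (the `Sequence<0, ..., 7, 9, 8>` destination access order above), and the result is stored so that consecutive writes land contiguously in global memory. A scalar sketch under assumed extents (`N2`, `M4`, and the `half_t` alias are illustrative, not the kernel's real parameters):

```cpp
#include <cstdio>

using half_t = _Float16; // assumes a compiler with _Float16 support

constexpr int N2 = 4, M4 = 4; // illustrative thread-tile extents

// Typecast -> permute -> store: a thread's accumulator fragment is indexed
// [n2][m4]; swapping the two indices on output turns the strided accumulator
// layout into contiguous writes, which the real kernel issues as vectors.
void epilogue_fragment(const float (&acc)[N2][M4], half_t (&out)[M4][N2])
{
    for(int n2 = 0; n2 < N2; ++n2)
        for(int m4 = 0; m4 < M4; ++m4)
            out[m4][n2] = static_cast<half_t>(acc[n2][m4]); // cast + permute
}

int main()
{
    float acc[N2][M4] = {};
    half_t out[M4][N2];
    epilogue_fragment(acc, out);
    std::printf("done\n");
}
```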