Commit 72a488cc authored by aska-0096

All layout sanity pass

parent ea41fc2f
@@ -154,9 +154,9 @@ struct BlockwiseGemmXdlops_pipeline_base
        return make_tuple(c_thread_m, c_thread_n);
    }

-    // Contiguous output tile
+    // NContiguous output tile
    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
-    __device__ static auto CalculateCThreadOriginDataIndexContiguous(Number<m0>,
+    __device__ static auto CalculateCThreadOriginDataIndexNContiguous(Number<m0>,
                                                                      Number<n0>,
                                                                      Number<xdlops_i>,
                                                                      Number<blk_i>)
@@ -186,6 +186,38 @@ struct BlockwiseGemmXdlops_pipeline_base
        return make_tuple(c_thread_m, c_thread_n);
    }
// MNContiguous output tile
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto CalculateCThreadOriginDataIndexMNContiguous(Number<m0>,
Number<n0>,
Number<xdlops_i>,
Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MWaves, MPerXDL, MRepeat))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(waveId_m, blk_idx[I0], m0))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(waveId_n, blk_idx[I1], n0))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
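Note (not part of the diff): the adaptor above flattens the (wave, lane-in-xdl, repeat) triple into a single M (or N) coordinate with the repeat index innermost, which is what makes the tile "MN-contiguous". A minimal host-side sketch of that index math, with placeholder tile factors and assuming the unmerge order (MWaves, MPerXDL, MRepeat) shown above:

```cpp
#include <cstdio>

// Placeholder tile factors, for illustration only.
constexpr int MWaves = 2, MPerXDL = 32, MRepeat = 4;

// What CalculateBottomIndex of unmerge(MWaves, MPerXDL, MRepeat) computes for the
// top index (waveId_m, blk_m, m0): the MRepeat copies owned by one lane of one
// wave end up adjacent in M.
int m_mn_contiguous(int waveId_m, int blk_m, int m0)
{
    return waveId_m * (MPerXDL * MRepeat) + blk_m * MRepeat + m0;
}

int main()
{
    // lane row 5 of wave 1, repeat 3 (arbitrary example values)
    std::printf("c_thread_m = %d\n", m_mn_contiguous(1, 5, 3));
    return 0;
}
```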
    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
@@ -246,19 +278,30 @@ struct BlockwiseGemmXdlops_pipeline_base
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

-    // Contiguous output tile
+    // N-Contiguous output tile
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
-        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
-        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
-            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, M1, N, Number<NRepeat>{}, M2));
+            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, I1, I1, Number<NRepeat>{}, M2));
    }
// MN-Contiguous output tile
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, M0, I1, I1, Number<MRepeat>{}, Number<NRepeat>{}, M2));
    }
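Note (my own sanity check, not from the source): both thread-tile descriptors are packed, so the per-thread accumulator count is identical for the N-contiguous and MN-contiguous layouts; only the dimension order, and therefore the store pattern, changes. A standalone sketch with placeholder values for the xdlops group lengths:

```cpp
// Standalone check, not CK code. M0/M2 stand in for the per-xdlops group count
// and group size returned by GetCM0M1M2NThreadBlkLengths(); values are assumed.
constexpr int MRepeat = 4, NRepeat = 2;
constexpr int M0 = 4, M2 = 4;

// N-contiguous:  (1,1,MRepeat,1,1,M0,1,1,NRepeat,M2)
// MN-contiguous: (1,1,1,1,M0,1,1,MRepeat,NRepeat,M2)
constexpr int n_contig_elems  = MRepeat * M0 * NRepeat * M2;
constexpr int mn_contig_elems = M0 * MRepeat * NRepeat * M2;
static_assert(n_contig_elems == mn_contig_elems,
              "same per-thread C element count, different traversal order");
```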
    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
@@ -315,6 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_base
        return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
    }
// TransposeA
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
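For orientation (my own note, not part of the diff): the packed 6-d source descriptor above spans exactly one MPerBlock x NPerBlock output tile, since MPerBlock = MRepeat * MWaves * MPerXDL and NPerBlock = NRepeat * NWaves * NPerXDL. A trivial standalone check with assumed numbers:

```cpp
// Assumed example factors (not from the diff).
constexpr int MRepeat = 4, MWaves = 2, MPerXDL = 32;
constexpr int NRepeat = 2, NWaves = 2, NPerXDL = 32;

constexpr int MPerBlock = MRepeat * MWaves * MPerXDL; // 256
constexpr int NPerBlock = NRepeat * NWaves * NPerXDL; // 128

// Element count of the packed (MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL) descriptor:
constexpr int block_desc_elems = MRepeat * NRepeat * MWaves * NWaves * MPerXDL * NPerXDL;
static_assert(block_desc_elems == MPerBlock * NPerBlock, "one full C block tile per block");
```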
    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
@@ -435,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_base
                                                      decltype(b_block_desc_n0_n1_n2_k),
                                                      decltype(b_thread_desc_),
                                                      Sequence<NRepeat, 1, 1, KPack>,
-                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<2, 0, 1, 3>,
                                                      Sequence<0, 1, 2, 3>,
                                                      3,
                                                      3,
...
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
-          index_t KPacks>
+          index_t KPacks,
+          bool TransposeA,
+          bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v4
{
};
@@ -55,7 +57,9 @@ template <index_t BlockSize,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                        KPack>
+                                        KPack,
+                                        TransposeA,
+                                        TransposeB>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
@@ -298,23 +308,19 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                               make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
+                               make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
                               a_block_buf.At(I0),
                               a_thread_desc_,
-                               make_tuple(m0, I0, k, I0),
+                               make_tuple(I0, I0, k, I0),
                               a_thread_bufs(I0));
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
                               b_block_buf.At(I0),
                               b_thread_desc_,
-                               make_tuple(n0, I0, k, I0),
+                               make_tuple(I0, I0, k, I0),
                               b_thread_bufs(I0));
-            });
-            });
        });
// Global prefetch 3 // Global prefetch 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
...@@ -349,24 +355,19 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -349,24 +355,19 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}), make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf), a_block_buf.At(lds_read_buf),
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(I0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) { b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
b_thread_copy_.Run( make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf), b_block_buf.At(lds_read_buf),
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k, I0), make_tuple(I0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf)); b_thread_bufs(lds_read_reg_buf));
}); });
});
});
a_blockwise_copy.RunWrite( a_blockwise_copy.RunWrite(
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
...@@ -430,23 +431,19 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -430,23 +431,19 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}), make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf), a_block_buf.At(lds_read_buf),
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(I0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf), b_block_buf.At(lds_read_buf),
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k, I0), make_tuple(I0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf)); b_thread_bufs(lds_read_reg_buf));
}); });
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
...@@ -489,23 +486,19 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -489,23 +486,19 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}), make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf), a_block_buf.At(lds_read_buf),
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(I0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf), b_block_buf.At(lds_read_buf),
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k, I0), make_tuple(I0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf)); b_thread_bufs(lds_read_reg_buf));
}); });
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
......
...@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer, ...@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL, index_t NPerXDL,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPacks> index_t KPacks,
bool TransposeA,
bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v5 struct BlockwiseGemmXdlops_pipeline_v5
{ {
}; };
...@@ -55,7 +57,9 @@ template <index_t BlockSize, ...@@ -55,7 +57,9 @@ template <index_t BlockSize,
index_t NPerXDL, index_t NPerXDL,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack index_t KPack,
bool TransposeA,
bool TransposeB
// ,bool TransposeC //disable transposec right now... // ,bool TransposeC //disable transposec right now...
> >
struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack,
TransposeA,
TransposeB>
: BlockwiseGemmXdlops_pipeline_base<BlockSize, : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack,
TransposeA,
TransposeB>
{ {
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize, using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
...@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>; KPack,
TransposeA,
TransposeB>;
using Base::A_K1; using Base::A_K1;
using Base::B_K1; using Base::B_K1;
using Base::I0; using Base::I0;
...@@ -381,22 +391,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -381,22 +391,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1 // Local prefetch 1
block_sync_lds(); block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
});
// main body // main body
if constexpr(HasMainLoop) if constexpr(HasMainLoop)
...@@ -449,26 +456,24 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -449,26 +456,24 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
});
a_thread_copy_.Run( a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k, a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k, b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
}); });
});
HotLoopScheduler(); HotLoopScheduler();
}; };
...@@ -517,25 +522,24 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -517,25 +522,24 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
});
a_thread_copy_.Run( a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k, a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k, b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
}); });
});
HotLoopScheduler(); HotLoopScheduler();
}; };
...@@ -567,26 +571,24 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -567,26 +571,24 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
});
a_thread_copy_.Run( a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k, a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k, b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
}); });
});
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -636,28 +638,81 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
-                                                         ComputeDataType,
-                                                         decltype(a_block_desc_m0_m1_m2_k),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, 1, 1, KPack>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         A_K1,
-                                                         A_K1>;
+    template <bool Transpose>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPack>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      A_K1,
+                                                      A_K1>;
+    };
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPack>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      MRepeat,
+                                                      A_K1>;
+    };

-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
-                                                         ComputeDataType,
-                                                         decltype(b_block_desc_n0_n1_n2_k),
-                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, 1, 1, KPack>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         B_K1,
-                                                         B_K1>;
+    template <bool Transpose>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPack>,
+                                                      Sequence<2, 0, 1, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      B_K1,
+                                                      B_K1>;
+    };
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPack>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      NRepeat,
+                                                      B_K1>;
+    };
+
+    typename AThreadCopySelector<TransposeA>::type a_thread_copy_{
+        Base::CalculateAThreadOriginDataIndex()};
+    typename BThreadCopySelector<TransposeB>::type b_thread_copy_{
+        Base::CalculateBThreadOriginDataIndex()};

-    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
};
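A rough standalone sketch (my own, not CK code) of what the Transpose=true selectors appear to set up: judging from the swapped vector-dim / scalar-per-vector arguments (0 and MRepeat/NRepeat instead of 3 and A_K1/B_K1), the slice is read so that the repeat dimension, rather than K, becomes the contiguous one in registers, which is the layout the packed 2x2 transposes further down expect. Treat the parameter interpretation as an assumption.

```cpp
#include <array>

// Scalar stand-in for the register-tile re-layout (illustrative only).
// src is a MRepeat x K1 patch stored K-fastest; dst stores it repeat-fastest.
template <int MRepeat, int K1>
std::array<float, MRepeat * K1> relayout_repeat_fastest(const std::array<float, MRepeat * K1>& src)
{
    std::array<float, MRepeat * K1> dst{};
    for(int m = 0; m < MRepeat; ++m)
        for(int k = 0; k < K1; ++k)
            dst[k * MRepeat + m] = src[m * K1 + k]; // swap which index varies fastest
    return dst;
}
```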
......
...@@ -235,25 +235,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -235,25 +235,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))), Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<1, 2, 0>{})); make_tuple(Sequence<3>{}, Sequence<1, 2, 0>{}));
#if 0
constexpr auto mma_transformed =
transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
mma_transformed,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_pass_through_transform(Number<MNWaves>{}),
make_pass_through_transform(Number<MNPerXdl>{}),
make_pass_through_transform(Number<MNXdlPerWave>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<0>{}));
#endif
} }
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
...@@ -448,21 +429,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -448,21 +429,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto b_mma_desc = [&]() {
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return MakeGemmMmaTileDescriptorCongruous<NXdlPerWave, NWaves, NPerXdl>( return MakeGemmMmaTileDescriptorCongruous<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{}); BBlockDesc_BK0_N_BK1{});
} }
}();
return b_mma_desc;
}
__host__ __device__ static auto __host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
...@@ -478,18 +447,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -478,18 +447,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
}(); }();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N // pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw, return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_tuple(make_right_pad_transform(M, MPad - M),
...@@ -497,33 +454,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -497,33 +454,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
struct Problem struct Problem
{ {
...@@ -723,92 +653,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -723,92 +653,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else // ColumnMajor A else // ColumnMajor A
{ {
#if 0
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
? 1
: ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
? M0
: 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * M1>{},
Number<kfold * M0 / mpair>{},
Number<mpair>{},
AK1Number));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
#endif
static_assert(ABlockTransferSrcScalarPerVector % MXdlPerWave == 0); static_assert(ABlockTransferSrcScalarPerVector % MXdlPerWave == 0);
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
...@@ -867,89 +711,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -867,89 +711,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else // RowMajor B else // RowMajor B
{ {
#if 0
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
? 1
: ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
? N0
: 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * N1>{},
Number<kfold * N0 / npair>{},
Number<npair>{},
BK1Number));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
#endif
static_assert(BBlockTransferSrcScalarPerVector % NXdlPerWave == 0); static_assert(BBlockTransferSrcScalarPerVector % NXdlPerWave == 0);
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
...@@ -958,19 +719,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -958,19 +719,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
} }
-    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
+    __device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock()
    {
-        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
-        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
-
-        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            make_naive_tensor_descriptor_packed(
-                make_tuple(I1,
-                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
-                           I1,
-                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
-
-        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
+        if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
+        {
+            return BlockwiseGemmPipe::GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4();
+        }
+        else
+        {
+            return BlockwiseGemmPipe::GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
+        }
    }
    using BlockwiseGemmPipe =
@@ -1413,9 +1171,96 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                          num_k_block_main_loop);

        // Epilogue
-        constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
-            blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
+        constexpr auto c_thread_desc_mblock_nblock = GetCThreadDescriptor_MBlock_NBlock();
auto c_block_trait = [&]() {
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
constexpr auto c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4();
constexpr auto M0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<0>{});
constexpr auto N0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<1>{});
constexpr auto M1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<3>{});
constexpr auto N1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<4>{});
constexpr auto M3 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M4, M3)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{},
Sequence<2, 4, 5, 9, 7>{},
Sequence<1>{},
Sequence<3, 6, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexMNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, I1, I1, M1, I1, I1, M3, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 8, 9, 7>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4);
}
else
{
constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 = constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(); blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
...@@ -1428,18 +1273,22 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1428,18 +1273,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{}); constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{}); constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor( const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock), make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)), make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_pass_through_transform(problem.NBlock), make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))), make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple( make_tuple(Sequence<0>{},
Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{})); Sequence<2, 3, 5, 6, 9>{},
Sequence<1>{},
Sequence<4, 7, 8>{}));
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0); blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
...@@ -1454,7 +1303,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1454,7 +1303,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block)); make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
...@@ -1467,7 +1317,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1467,7 +1317,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4< auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType, AccDataType,
CDataType, CDataType,
decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4), decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4), decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
CElementwiseOperation, CElementwiseOperation,
Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>, Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
...@@ -1491,10 +1341,18 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1491,10 +1341,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
m_thread_data_on_block_idx[I4]), m_thread_data_on_block_idx[I4]),
c_element_op}; c_element_op};
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4, return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4);
}
}();
auto c_thread_copy_vgpr_to_global = c_block_trait.At(Number<0>{});
auto c_grid_desc_mblock_nblock = c_block_trait.At(Number<1>{});
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4, c_grid_desc_mblock_nblock,
c_grid_buf); c_grid_buf);
} }
@@ -1607,11 +1465,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
-                                                Sequence<0, 1, 2>,
+                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
-                                                2,
+                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
-                                                ABlockTransferDstScalarPerVector_AK1,
+                                                ABlockTransferSrcScalarPerVector,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
@@ -1638,11 +1496,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
-                                                Sequence<0, 1, 2>,
+                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
-                                                2,
+                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_BK1,
+                                                BBlockTransferSrcScalarPerVector,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
@@ -1706,9 +1564,96 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                          num_k_block_main_loop);

        // Epilogue
-        constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
-            blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
+        constexpr auto c_thread_desc_mblock_nblock = GetCThreadDescriptor_MBlock_NBlock();
auto c_block_trait = [&]() {
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
constexpr auto c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4();
constexpr auto M0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<0>{});
constexpr auto N0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<1>{});
constexpr auto M1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<3>{});
constexpr auto N1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<4>{});
constexpr auto M3 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M4, M3)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{},
Sequence<2, 4, 5, 9, 7>{},
Sequence<1>{},
Sequence<3, 6, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexMNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, I1, I1, M1, I1, I1, M3, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 8, 9, 7>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4);
}
else
{
constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 = constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(); blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
...@@ -1721,18 +1666,22 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1721,18 +1666,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{}); constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{}); constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor( const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock), make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)), make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_pass_through_transform(problem.NBlock), make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))), make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple( make_tuple(Sequence<0>{},
Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{})); Sequence<2, 3, 5, 6, 9>{},
Sequence<1>{},
Sequence<4, 7, 8>{}));
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0); blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
...@@ -1747,7 +1696,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1747,7 +1696,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block)); make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
...@@ -1760,7 +1710,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1760,7 +1710,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4< auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType, AccDataType,
CDataType, CDataType,
decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4), decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4), decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
CElementwiseOperation, CElementwiseOperation,
Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>, Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
...@@ -1784,10 +1734,18 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1784,10 +1734,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
m_thread_data_on_block_idx[I4]), m_thread_data_on_block_idx[I4]),
c_element_op}; c_element_op};
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4, return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4);
}
}();
auto c_thread_copy_vgpr_to_global = c_block_trait.At(Number<0>{});
auto c_grid_desc_mblock_nblock = c_block_trait.At(Number<1>{});
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4, c_grid_desc_mblock_nblock,
c_grid_buf); c_grid_buf);
} }
......
@@ -983,6 +983,40 @@ struct XdlopsGemm
                       Sequence<5>{}));
    }
// TransposeA
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<7>{},
Sequence<6>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2, 3, 5>{},
Sequence<4>{}));
}
    // transposed XDL output supporting C' = B' * A'
    // M2_N2 -> M2_N2_N3_N4
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
...
@@ -18,22 +18,6 @@ struct transpose_vectors;
// transpose fp16 2x2
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
{
#if 0
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<half_t, 2> vx0{x0}, vx1{x1};
vector_type<half_t, 2> vy0, vy1;
vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else
    constexpr int32_t m0 = 0x05040100;
    constexpr int32_t m1 = 0x07060302;
@@ -43,7 +27,6 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
    // index is reversed because of little endianness (least significant bits first)
    y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
    y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
-#endif
}
template <index_t NX, index_t NY>
@@ -83,6 +66,60 @@ struct transpose_vectors<half_t, NX, NY>
    }
};
// transpose bf16 2x2
__device__ void
transpose_bf16_2x2(const bhalf2_t& x0, const bhalf2_t& x1, bhalf2_t& y0, bhalf2_t& y1)
{
constexpr int32_t m0 = 0x05040100;
constexpr int32_t m1 = 0x07060302;
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
y0 =
bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
y1 =
bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
}
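Host-side reference for the 2x2 transpose above (my sketch, not part of the change): with the operand order used here, selector 0x05040100 picks the low 16-bit halves of x0 and x1 and 0x07060302 the high halves, so for x0 = {a0, a1}, x1 = {b0, b1} the expected outputs are y0 = {a0, b0} and y1 = {a1, b1}. uint16_t stands in for bhalf_t.

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    const uint16_t x0[2] = {0x3F80, 0x3FC0}; // a0, a1 (arbitrary bf16 bit patterns)
    const uint16_t x1[2] = {0x4000, 0x4040}; // b0, b1

    // What the two v_perm_b32 selections are expected to produce:
    const uint16_t y0[2] = {x0[0], x1[0]}; // selector 0x05040100
    const uint16_t y1[2] = {x0[1], x1[1]}; // selector 0x07060302

    std::printf("y0 = {%04x, %04x}\ny1 = {%04x, %04x}\n", y0[0], y0[1], y1[0], y1[1]);
    return 0;
}
```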
template <index_t NX, index_t NY>
struct transpose_vectors<bhalf_t, NX, NY>
{
// we got [NY * NX] amount of S data to be transposed
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = bhalf_t;
using VX = vector_type<bhalf_t, s_per_x>;
using VY = vector_type<bhalf_t, s_per_y>;
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
StaticallyIndexedArray<VY&, NY>& vy_tuple)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// reference to 2 bhalf2_t data from vx_tuple
const auto& x_s2_0 = vx_tuple[ix].template AsType<bhalf2_t>()[iy / I2];
const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<bhalf2_t>()[iy / I2];
// reference to 2 bhalf2_t data from vy_tuple
auto& y_s2_0 = vy_tuple(iy).template AsType<bhalf2_t>()(ix / I2);
auto& y_s2_1 = vy_tuple(iy + I1).template AsType<bhalf2_t>()(ix / I2);
// transpose
transpose_bf16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
});
});
}
};
// transpose int8 4x4
__device__ void transpose_int8_4x4(const int8x4_t& x0,
                                   const int8x4_t& x1,
...