Commit 72a488cc authored by aska-0096

All layout sanity pass

parent ea41fc2f
@@ -154,12 +154,12 @@ struct BlockwiseGemmXdlops_pipeline_base
         return make_tuple(c_thread_m, c_thread_n);
     }

-    // Contiguous output tile
+    // NContiguous output tile
     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
-    __device__ static auto CalculateCThreadOriginDataIndexContiguous(Number<m0>,
+    __device__ static auto CalculateCThreadOriginDataIndexNContiguous(Number<m0>,
                                                                      Number<n0>,
                                                                      Number<xdlops_i>,
                                                                      Number<blk_i>)
     {
         const auto wave_idx = GetWaveIdx();
@@ -186,6 +186,38 @@ struct BlockwiseGemmXdlops_pipeline_base
         return make_tuple(c_thread_m, c_thread_n);
     }

+    // MNContiguous output tile
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto CalculateCThreadOriginDataIndexMNContiguous(Number<m0>,
+                                                                       Number<n0>,
+                                                                       Number<xdlops_i>,
+                                                                       Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
+
+        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(MWaves, MPerXDL, MRepeat))),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 2>{}));
+
+        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 2>{}));
+
+        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
+            make_tuple(waveId_m, blk_idx[I0], m0))[I0];
+        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
+            make_tuple(waveId_n, blk_idx[I1], n0))[I0];
+
+        return make_tuple(c_thread_m, c_thread_n);
+    }
+
     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
     __device__ static auto
     CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
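Note: make_unmerge_transform(make_tuple(MWaves, MPerXDL, MRepeat)) splits a flat M index with MRepeat varying fastest, so CalculateBottomIndex folds (waveId_m, blk_idx[I0], m0) back into a single coordinate. A standalone sketch of that arithmetic (invented example values, assuming the conventional row-major fold):

#include <cstdio>

int main()
{
    // example tile parameters, not taken from any kernel config
    const int MPerXDL = 32, MRepeat = 4;
    const int waveId_m = 1, blk_m = 5, m0 = 3;

    // what CalculateBottomIndex computes for the adaptor above
    const int c_thread_m = (waveId_m * MPerXDL + blk_m) * MRepeat + m0;

    std::printf("c_thread_m = %d\n", c_thread_m); // (1*32 + 5)*4 + 3 = 151
    return 0;
}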
@@ -246,19 +278,30 @@ struct BlockwiseGemmXdlops_pipeline_base
             make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
     }

-    // Contiguous output tile
+    // N-Contiguous output tile
     __host__ __device__ static constexpr auto
     GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
     {
         constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

         constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
-        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
         constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
-        constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];

         return make_naive_tensor_descriptor_packed(
-            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, M1, N, Number<NRepeat>{}, M2));
+            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, I1, I1, Number<NRepeat>{}, M2));
     }

+    // MN-Contiguous output tile
+    __host__ __device__ static constexpr auto
+    GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4()
+    {
+        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+
+        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(I1, I1, I1, I1, M0, I1, I1, Number<MRepeat>{}, Number<NRepeat>{}, M2));
+    }
+
     __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
@@ -315,6 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_base
         return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
     }

+    // TransposeA
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
     __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
@@ -435,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_base
                                                          decltype(b_block_desc_n0_n1_n2_k),
                                                          decltype(b_thread_desc_),
                                                          Sequence<NRepeat, 1, 1, KPack>,
-                                                         Sequence<0, 1, 2, 3>,
+                                                         Sequence<2, 0, 1, 3>,
                                                          Sequence<0, 1, 2, 3>,
                                                          3,
                                                          3,
...
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPacks>
+          index_t KPacks,
+          bool TransposeA,
+          bool TransposeB>
 struct BlockwiseGemmXdlops_pipeline_v4
 {
 };
@@ -55,7 +57,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                         NPerXDL,
                                         MRepeat,
                                         NRepeat,
-                                        KPack>
+                                        KPack,
+                                        TransposeA,
+                                        TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
     using Base::I0;
     using Base::I1;
     using Base::KRepeat;
@@ -298,22 +308,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
             // Local prefetch 1
             block_sync_lds();
             static_for<0, KRepeat, 1>{}([&](auto k) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                       a_block_buf.At(I0),
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k, I0),
-                                       a_thread_bufs(I0));
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                           b_block_buf.At(I0),
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k, I0),
-                                           b_thread_bufs(I0));
-                    });
-                });
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                   a_block_buf.At(I0),
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k, I0),
+                                   a_thread_bufs(I0));
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                   b_block_buf.At(I0),
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k, I0),
+                                   b_thread_bufs(I0));
             });

             // Global prefetch 3
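Note: the same simplification repeats in every read stage of this file and the v5 pipeline below: the per-m0/per-n0 static_for loops are gone because each thread copy now moves the whole MRepeat (resp. NRepeat) slice in one Run, per the Sequence<MRepeat, 1, 1, KPack> / Sequence<NRepeat, 1, 1, KPack> slice lengths of the new copy types. In miniature, with copy() as an invented stand-in:

// before: one transfer per repeat index
// for(int m0 = 0; m0 < MRepeat; ++m0)
//     copy(src_at(m0), dst_at(m0), /*slice_len=*/1);
//
// after: one transfer whose slice spans the repeat dimension
// copy(src_at(0), dst_at(0), /*slice_len=*/MRepeat);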
@@ -349,23 +355,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                 block_sync_lds();
                 static_for<0, KRepeat, 1>{}([&](auto k) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                           a_block_buf.At(lds_read_buf),
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k, I0),
-                                           a_thread_bufs(lds_read_reg_buf));
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(
-                                b_block_desc_n0_n1_n2_k,
-                                make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                b_block_buf.At(lds_read_buf),
-                                b_thread_desc_,
-                                make_tuple(n0, I0, k, I0),
-                                b_thread_bufs(lds_read_reg_buf));
-                        });
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                       a_block_buf.At(lds_read_buf),
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       a_thread_bufs(lds_read_reg_buf));
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                       b_block_buf.At(lds_read_buf),
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       b_thread_bufs(lds_read_reg_buf));
                 });

                 a_blockwise_copy.RunWrite(
@@ -430,22 +431,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                 block_sync_lds();
                 static_for<0, KRepeat, 1>{}([&](auto k) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                           a_block_buf.At(lds_read_buf),
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k, I0),
-                                           a_thread_bufs(lds_read_reg_buf));
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                               b_block_buf.At(lds_read_buf),
-                                               b_thread_desc_,
-                                               make_tuple(n0, I0, k, I0),
-                                               b_thread_bufs(lds_read_reg_buf));
-                        });
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                       a_block_buf.At(lds_read_buf),
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       a_thread_bufs(lds_read_reg_buf));
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                       b_block_buf.At(lds_read_buf),
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       b_thread_bufs(lds_read_reg_buf));
                 });

                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
@@ -489,22 +486,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                 block_sync_lds();
                 static_for<0, KRepeat, 1>{}([&](auto k) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                           a_block_buf.At(lds_read_buf),
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k, I0),
-                                           a_thread_bufs(lds_read_reg_buf));
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                               b_block_buf.At(lds_read_buf),
-                                               b_thread_desc_,
-                                               make_tuple(n0, I0, k, I0),
-                                               b_thread_bufs(lds_read_reg_buf));
-                        });
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                       a_block_buf.At(lds_read_buf),
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       a_thread_bufs(lds_read_reg_buf));
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                       b_block_buf.At(lds_read_buf),
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       b_thread_bufs(lds_read_reg_buf));
                 });

                 static_for<0, KRepeat, 1>{}([&](auto k0) {
...
@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPacks>
+          index_t KPacks,
+          bool TransposeA,
+          bool TransposeB>
 struct BlockwiseGemmXdlops_pipeline_v5
 {
 };
@@ -55,7 +57,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                         NPerXDL,
                                         MRepeat,
                                         NRepeat,
-                                        KPack>
+                                        KPack,
+                                        TransposeA,
+                                        TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
     using Base::A_K1;
     using Base::B_K1;
     using Base::I0;
@@ -381,22 +391,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
         // Local prefetch 1
         block_sync_lds();
-        static_for<0, MRepeat, 1>{}([&](auto m0) {
-            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                               make_tuple(m0, I0, I0, I0),
-                               a_block_buf,
-                               a_thread_desc_,
-                               make_tuple(m0, I0, I0, I0),
-                               a_thread_buf);
-        });
-        static_for<0, NRepeat, 1>{}([&](auto n0) {
-            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                               make_tuple(n0, I0, I0, I0),
-                               b_block_buf,
-                               b_thread_desc_,
-                               make_tuple(n0, I0, I0, I0),
-                               b_thread_buf);
-        });
+        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                           make_tuple(I0, I0, I0, I0),
+                           a_block_buf,
+                           a_thread_desc_,
+                           make_tuple(I0, I0, I0, I0),
+                           a_thread_buf);
+
+        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                           make_tuple(I0, I0, I0, I0),
+                           b_block_buf,
+                           b_thread_desc_,
+                           make_tuple(I0, I0, I0, I0),
+                           b_thread_buf);

         // main body
         if constexpr(HasMainLoop)
@@ -449,25 +456,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                             b_thread_vec.template AsType<mfma_input_type>(),
                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                     });
-                    a_thread_copy_.Run(
-                        a_block_desc_m0_m1_m2_k,
-                        make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
-                        a_block_buf,
-                        a_thread_desc_,
-                        make_tuple(m0, I0, I0, I0),
-                        a_thread_buf);
                 });
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    b_thread_copy_.Run(
-                        b_block_desc_n0_n1_n2_k,
-                        make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
-                        b_block_buf,
-                        b_thread_desc_,
-                        make_tuple(n0, I0, I0, I0),
-                        b_thread_buf);
-                });
+                a_thread_copy_.Run(
+                    a_block_desc_m0_m1_m2_k,
+                    make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
+                    a_block_buf,
+                    a_thread_desc_,
+                    make_tuple(I0, I0, I0, I0),
+                    a_thread_buf);
+
+                b_thread_copy_.Run(
+                    b_block_desc_n0_n1_n2_k,
+                    make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
+                    b_block_buf,
+                    b_thread_desc_,
+                    make_tuple(I0, I0, I0, I0),
+                    b_thread_buf);
             });

             HotLoopScheduler();
@@ -517,24 +522,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                             b_thread_vec.template AsType<mfma_input_type>(),
                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                     });
-                    a_thread_copy_.Run(
-                        a_block_desc_m0_m1_m2_k,
-                        make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
-                        a_block_buf,
-                        a_thread_desc_,
-                        make_tuple(m0, I0, I0, I0),
-                        a_thread_buf);
                 });
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    b_thread_copy_.Run(
-                        b_block_desc_n0_n1_n2_k,
-                        make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
-                        b_block_buf,
-                        b_thread_desc_,
-                        make_tuple(n0, I0, I0, I0),
-                        b_thread_buf);
-                });
+                a_thread_copy_.Run(
+                    a_block_desc_m0_m1_m2_k,
+                    make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
+                    a_block_buf,
+                    a_thread_desc_,
+                    make_tuple(I0, I0, I0, I0),
+                    a_thread_buf);
+
+                b_thread_copy_.Run(
+                    b_block_desc_n0_n1_n2_k,
+                    make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
+                    b_block_buf,
+                    b_thread_desc_,
+                    make_tuple(I0, I0, I0, I0),
+                    b_thread_buf);
             });

             HotLoopScheduler();
@@ -567,25 +571,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                             b_thread_vec.template AsType<mfma_input_type>(),
                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                     });
-                    a_thread_copy_.Run(
-                        a_block_desc_m0_m1_m2_k,
-                        make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
-                        a_block_buf,
-                        a_thread_desc_,
-                        make_tuple(m0, I0, I0, I0),
-                        a_thread_buf);
                 });
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    b_thread_copy_.Run(
-                        b_block_desc_n0_n1_n2_k,
-                        make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
-                        b_block_buf,
-                        b_thread_desc_,
-                        make_tuple(n0, I0, I0, I0),
-                        b_thread_buf);
-                });
+                a_thread_copy_.Run(
+                    a_block_desc_m0_m1_m2_k,
+                    make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
+                    a_block_buf,
+                    a_thread_desc_,
+                    make_tuple(I0, I0, I0, I0),
+                    a_thread_buf);
+
+                b_thread_copy_.Run(
+                    b_block_desc_n0_n1_n2_k,
+                    make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
+                    b_block_buf,
+                    b_thread_desc_,
+                    make_tuple(I0, I0, I0, I0),
+                    b_thread_buf);
             });

             static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -636,28 +638,81 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
     static constexpr auto b_thread_desc_ =
         make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
-                                                         ComputeDataType,
-                                                         decltype(a_block_desc_m0_m1_m2_k),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, 1, 1, KPack>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         A_K1,
-                                                         A_K1>;
-
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
-                                                         ComputeDataType,
-                                                         decltype(b_block_desc_n0_n1_n2_k),
-                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, 1, 1, KPack>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         B_K1,
-                                                         B_K1>;
-
-    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
+    template <bool Transpose>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPack>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      A_K1,
+                                                      A_K1>;
+    };
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPack>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      MRepeat,
+                                                      A_K1>;
+    };
+
+    template <bool Transpose>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPack>,
+                                                      Sequence<2, 0, 1, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      B_K1,
+                                                      B_K1>;
+    };
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPack>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      NRepeat,
+                                                      B_K1>;
+    };
+
+    typename AThreadCopySelector<TransposeA>::type a_thread_copy_{
+        Base::CalculateAThreadOriginDataIndex()};
+    typename BThreadCopySelector<TransposeB>::type b_thread_copy_{
+        Base::CalculateBThreadOriginDataIndex()};

     using Base::c_thread_desc_;
 };
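Note: AThreadCopySelector/BThreadCopySelector is plain compile-time dispatch on the new bool parameters: false keeps the K-vectorized copy (vector dim 3, vector size A_K1/B_K1), true swaps in a transposing copy that vectorizes along the repeat dimension (dim access order Sequence<3, 1, 2, 0>, vector dim 0, vector size MRepeat/NRepeat). A minimal sketch of the same idiom, with invented placeholder policies:

#include <type_traits>

struct LinearCopy {};     // stand-in for the K-contiguous transfer
struct TransposedCopy {}; // stand-in for the repeat-dim-vectorized transfer

template <bool Transpose>
struct CopySelector { using type = LinearCopy; };

template <>
struct CopySelector<true> { using type = TransposedCopy; };

// the bool template argument picks the copy policy at compile time
using ACopy = typename CopySelector</*TransposeA=*/true>::type;
static_assert(std::is_same<ACopy, TransposedCopy>::value, "selector picks the transposed copy");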
...
@@ -983,6 +983,40 @@ struct XdlopsGemm
                        Sequence<5>{}));
     }

+    // TransposeA
+    template <typename CDesc_M0_N0_M1_N1_M2_N2>
+    __host__ __device__ static constexpr auto
+    MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
+    {
+        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+
+        return transform_tensor_descriptor(
+            c_desc_m0_n0_m1_n1_m2_n2,
+            make_tuple(make_pass_through_transform(M0),
+                       make_pass_through_transform(N0),
+                       make_pass_through_transform(M1),
+                       make_pass_through_transform(N1),
+                       make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
+                                                         Number<mfma_instr.num_input_blks>{},
+                                                         Number<mfma_instr.group_size>{})),
+                       make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<7>{},
+                       Sequence<6>{},
+                       Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2, 3, 5>{},
+                       Sequence<4>{}));
+    }
+
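Note: reading the Sequence arguments off this transform, the lower-to-upper dimension mapping works out to:

// lower dim 0: M0 (MRepeat)             -> upper dim 7
// lower dim 1: N0 (NRepeat)             -> upper dim 6
// lower dim 2: M1 (MWaves)              -> upper dim 0
// lower dim 3: N1 (NWaves)              -> upper dim 1
// lower dim 4: M2 (MPerXDL), unmerged into
//   (num_groups_per_blk, num_input_blks, group_size) -> upper dims 2, 3, 5
// lower dim 5: N2 (num_threads_per_blk) -> upper dim 4

i.e. the repeat coordinates move to the two innermost positions with MRepeat innermost, which is what an M-contiguous (TransposeA) output layout needs.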
     // transposed XDL output supporting C' = B' * A'
     // M2_N2 -> M2_N2_N3_N4
     template <typename CDesc_M0_N0_M1_N1_M2_N2>
...
@@ -18,22 +18,6 @@ struct transpose_vectors;
 // transpose fp16 2x2
 __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
 {
-#if 0
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-
-    const vector_type<half_t, 2> vx0{x0}, vx1{x1};
-    vector_type<half_t, 2> vy0, vy1;
-
-    vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
-    vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
-
-    vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
-    vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
-
-    y0 = vy0.template AsType<half2_t>()[I0];
-    y1 = vy1.template AsType<half2_t>()[I0];
-#else
     constexpr int32_t m0 = 0x05040100;
     constexpr int32_t m1 = 0x07060302;
@@ -43,7 +27,6 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
     // index is reversed because of little endianness (least significant bits first)
     y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
     y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
-#endif
 }

 template <index_t NX, index_t NY>
@@ -83,6 +66,60 @@ struct transpose_vectors<half_t, NX, NY>
     }
 };

+// transpose bf16 2x2
+__device__ void
+transpose_bf16_2x2(const bhalf2_t& x0, const bhalf2_t& x1, bhalf2_t& y0, bhalf2_t& y1)
+{
+    constexpr int32_t m0 = 0x05040100;
+    constexpr int32_t m1 = 0x07060302;
+
+    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+    //     byte index:   7  6  5  4      3  2  1  0  selects:  33 77 44 88
+    // index is reversed because of little endianness (least significant bits first)
+    y0 =
+        bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
+    y1 =
+        bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
+}
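A quick host-side way to sanity-check the two masks, mirroring the worked v_perm_b32 example in the comment above (perm_b32 is an invented model covering selector bytes 0-7 only, not the real intrinsic):

#include <cstdint>
#include <cstdio>

// model of v_perm_b32: result byte i = pool[sel byte i], where pool
// bytes 0..3 come from lo (src1) and bytes 4..7 from hi (src0)
static uint32_t perm_b32(uint32_t hi, uint32_t lo, uint32_t sel)
{
    const uint64_t pool = (static_cast<uint64_t>(hi) << 32) | lo;
    uint32_t out        = 0;
    for(int i = 0; i < 4; ++i)
    {
        const uint32_t s = (sel >> (8 * i)) & 0xff;  // selector byte i
        const uint32_t b = (pool >> (8 * s)) & 0xff; // selected pool byte
        out |= b << (8 * i);
    }
    return out;
}

int main()
{
    // x0 = {a0, a1}, x1 = {b0, b1}: one packed bf16 pair per 32-bit word
    const uint32_t x0 = 0x44332211u, x1 = 0x88776655u;

    // y0 = {a0, b0}: the low halves of both inputs
    std::printf("0x%08x\n", static_cast<unsigned>(perm_b32(x1, x0, 0x05040100u))); // 0x66552211
    // y1 = {a1, b1}: the high halves of both inputs
    std::printf("0x%08x\n", static_cast<unsigned>(perm_b32(x1, x0, 0x07060302u))); // 0x88774433
    return 0;
}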
+template <index_t NX, index_t NY>
+struct transpose_vectors<bhalf_t, NX, NY>
+{
+    // we got [NY * NX] amount of S data to be transposed
+    static constexpr index_t s_per_x = NY;
+    static constexpr index_t s_per_y = NX;
+
+    using S  = bhalf_t;
+    using VX = vector_type<bhalf_t, s_per_x>;
+    using VY = vector_type<bhalf_t, s_per_y>;
+
+    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
+                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
+    {
+        static constexpr auto I1 = Number<1>{};
+        static constexpr auto I2 = Number<2>{};
+
+        static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
+
+        // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 2>{}([&](auto iy) {
+            static_for<0, NX, 2>{}([&](auto ix) {
+                // reference to 2 bhalf2_t data from vx_tuple
+                const auto& x_s2_0 = vx_tuple[ix].template AsType<bhalf2_t>()[iy / I2];
+                const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<bhalf2_t>()[iy / I2];
+
+                // reference to 2 bhalf2_t data from vy_tuple
+                auto& y_s2_0 = vy_tuple(iy).template AsType<bhalf2_t>()(ix / I2);
+                auto& y_s2_1 = vy_tuple(iy + I1).template AsType<bhalf2_t>()(ix / I2);
+
+                // transpose
+                transpose_bf16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
+            });
+        });
+    }
+};
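Note: in plain array terms, the tiled loop computes vy_tuple[iy][ix] = vx_tuple[ix][iy], two scalars per transpose_bf16_2x2 call. A scalar host-side reference of that contract (invented names; unsigned short stands in for bf16 storage):

#include <cstdio>

int main()
{
    constexpr int NX = 4, NY = 2; // both must be even, as the static_assert requires
    unsigned short x[NX][NY], y[NY][NX];

    for(int ix = 0; ix < NX; ++ix)
        for(int iy = 0; iy < NY; ++iy)
            x[ix][iy] = static_cast<unsigned short>(100 * ix + iy);

    // reference semantics of the vectorized 2x2 kernel: y[iy][ix] = x[ix][iy]
    for(int iy = 0; iy < NY; ++iy)
        for(int ix = 0; ix < NX; ++ix)
            y[iy][ix] = x[ix][iy];

    std::printf("y[1][2] = %u\n", y[1][2]); // = x[2][1] = 201
    return 0;
}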
 // transpose int8 4x4
 __device__ void transpose_int8_4x4(const int8x4_t& x0,
                                    const int8x4_t& x1,
...