Commit 72a488cc authored by aska-0096

All layout sanity pass

parent ea41fc2f
......@@ -154,12 +154,12 @@ struct BlockwiseGemmXdlops_pipeline_base
return make_tuple(c_thread_m, c_thread_n);
}
// Contiguous output tile
// N-Contiguous output tile
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto CalculateCThreadOriginDataIndexContiguous(Number<m0>,
Number<n0>,
Number<xdlops_i>,
Number<blk_i>)
__device__ static auto CalculateCThreadOriginDataIndexNContiguous(Number<m0>,
Number<n0>,
Number<xdlops_i>,
Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
......@@ -186,6 +186,38 @@ struct BlockwiseGemmXdlops_pipeline_base
return make_tuple(c_thread_m, c_thread_n);
}
// MN-Contiguous output tile
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto CalculateCThreadOriginDataIndexMNContiguous(Number<m0>,
Number<n0>,
Number<xdlops_i>,
Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MWaves, MPerXDL, MRepeat))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(waveId_m, blk_idx[I0], m0))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(waveId_n, blk_idx[I1], n0))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
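The adaptor math above is just the inverse of an unmerge: CalculateBottomIndex folds (waveId_m, blk_m, m0) back into a single M coordinate with MRepeat varying fastest. A minimal host-side sketch of that fold, with assumed example sizes (not the CK API):

#include <cstdio>

int main()
{
    const int MPerXDL = 32, MRepeat = 4; // assumed sizes, for illustration only
    // make_unmerge_transform(make_tuple(MWaves, MPerXDL, MRepeat)) splits m so
    // that m = (waveId_m * MPerXDL + blk_m) * MRepeat + m0; the adaptor's
    // CalculateBottomIndex is exactly this fold.
    auto c_thread_m = [&](int waveId_m, int blk_m, int m0) {
        return (waveId_m * MPerXDL + blk_m) * MRepeat + m0;
    };
    printf("%d\n", c_thread_m(1, 5, 3)); // (1*32 + 5)*4 + 3 = 151
    return 0;
}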
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
......@@ -246,19 +278,30 @@ struct BlockwiseGemmXdlops_pipeline_base
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// Contiguous output tile
// N-Contiguous output tile
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, M1, N, Number<NRepeat>{}, M2));
make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, I1, I1, Number<NRepeat>{}, M2));
}
// MN-Contiguous output tile
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, M0, I1, I1, Number<MRepeat>{}, Number<NRepeat>{}, M2));
}
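As a reference point for the descriptors above: make_naive_tensor_descriptor_packed lays dimensions out contiguously, so strides are suffix products of the lengths (rightmost dimension fastest). A small host-side sketch of that stride rule (illustrative, not the CK implementation):

#include <array>
#include <cstdio>

int main()
{
    // e.g. the trailing non-unit lengths of a packed thread descriptor
    std::array<int, 4> lengths{4, 2, 4, 4};
    std::array<int, 4> strides{};
    int s = 1;
    for(int d = 3; d >= 0; --d)
    {
        strides[d] = s; // stride of dim d = product of lengths to its right
        s *= lengths[d];
    }
    printf("%d %d %d %d\n", strides[0], strides[1], strides[2], strides[3]); // 32 16 4 1
    return 0;
}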
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
......@@ -315,6 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_base
return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// TransposeA
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
......@@ -435,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_base
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<NRepeat, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
Sequence<2, 0, 1, 3>,
Sequence<0, 1, 2, 3>,
3,
3,
......
......@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
index_t KPacks,
bool TransposeA,
bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v4
{
};
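The empty primary template plus per-scheduler partial specialization is a compile-time dispatch idiom: only supported (scheduler, transpose) combinations get a definition, and anything else fails to instantiate. A minimal sketch of the pattern with hypothetical names:

enum class Scheduler { Intrawave, Interwave };

template <Scheduler S, bool TransposeA, bool TransposeB>
struct Pipeline
{
}; // primary template: intentionally empty

template <bool TransposeA, bool TransposeB>
struct Pipeline<Scheduler::Intrawave, TransposeA, TransposeB>
{
    static constexpr const char* name = "intrawave"; // a real pipeline would live here
};

static_assert(Pipeline<Scheduler::Intrawave, true, false>::name != nullptr, "");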
......@@ -55,7 +57,9 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
index_t KPack,
bool TransposeA,
bool TransposeB
// ,bool TransposeC // TransposeC is disabled for now
>
struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
......@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeA,
TransposeB>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
......@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeA,
TransposeB>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
......@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeA,
TransposeB>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
......@@ -298,22 +308,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(I0, I0, k, I0),
a_thread_bufs(I0));
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(I0, I0, k, I0),
b_thread_bufs(I0));
});
// Global prefetch 3
......@@ -349,23 +355,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(I0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(I0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
a_blockwise_copy.RunWrite(
......@@ -430,22 +431,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(I0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(I0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
......@@ -489,22 +486,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(I0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(I0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
......
......@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
index_t KPacks,
bool TransposeA,
bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v5
{
};
......@@ -55,7 +57,9 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
index_t KPack,
bool TransposeA,
bool TransposeB
// ,bool TransposeC // TransposeC is disabled for now
>
struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
......@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeA,
TransposeB>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
......@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeA,
TransposeB>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
......@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeA,
TransposeB>;
using Base::A_K1;
using Base::B_K1;
using Base::I0;
......@@ -381,22 +391,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// main body
if constexpr(HasMainLoop)
......@@ -449,25 +456,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
});
HotLoopScheduler();
......@@ -517,24 +522,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
});
HotLoopScheduler();
......@@ -567,25 +571,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
......@@ -636,28 +638,81 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
template <bool Transpose>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<MRepeat, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
3,
3,
A_K1,
A_K1>;
};
template <>
struct AThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<MRepeat, 1, 1, KPack>,
Sequence<3, 1, 2, 0>,
Sequence<0, 1, 2, 3>,
0,
3,
MRepeat,
A_K1>;
};
template <bool Transpose>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<NRepeat, 1, 1, KPack>,
Sequence<2, 0, 1, 3>,
Sequence<0, 1, 2, 3>,
3,
3,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<NRepeat, 1, 1, KPack>,
Sequence<3, 1, 2, 0>,
Sequence<0, 1, 2, 3>,
0,
3,
NRepeat,
B_K1>;
};
typename AThreadCopySelector<TransposeA>::type a_thread_copy_{
Base::CalculateAThreadOriginDataIndex()};
typename BThreadCopySelector<TransposeB>::type b_thread_copy_{
Base::CalculateBThreadOriginDataIndex()};
using Base::c_thread_desc_;
};
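The ThreadCopySelector structs above act as a bool-indexed type switch, chosen by the TransposeA/TransposeB template parameters. Functionally the same selection can be sketched with std::conditional (the copy types here are placeholders, not the CK transfers):

#include <type_traits>

struct DirectThreadCopy {};     // stand-in for the Transpose == false transfer
struct TransposedThreadCopy {}; // stand-in for the Transpose == true transfer

template <bool Transpose>
using ThreadCopy =
    typename std::conditional<Transpose, TransposedThreadCopy, DirectThreadCopy>::type;

static_assert(std::is_same<ThreadCopy<false>, DirectThreadCopy>::value, "");
static_assert(std::is_same<ThreadCopy<true>, TransposedThreadCopy>::value, "");

The explicit-specialization form used in the diff buys more freedom than std::conditional: the two transfers differ in dimension order, vector dimension, and scalar-per-vector arguments, not just in a flag.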
......
......@@ -235,25 +235,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<1, 2, 0>{}));
#if 0
constexpr auto mma_transformed =
transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNWaves>{}, Number<MNPerXdl>{}, Number<MNXdlPerWave>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
mma_transformed,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_pass_through_transform(Number<MNWaves>{}),
make_pass_through_transform(Number<MNPerXdl>{}),
make_pass_through_transform(Number<MNXdlPerWave>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<0>{}));
#endif
}
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
......@@ -448,20 +429,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto b_mma_desc = [&]() {
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return MakeGemmMmaTileDescriptorCongruous<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
}();
return b_mma_desc;
return MakeGemmMmaTileDescriptorCongruous<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static auto
......@@ -484,45 +453,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
struct Problem
......@@ -723,92 +653,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else // ColumnMajor A
{
#if 0
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=m0
constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
? 1
: ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
? M0
: 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * M1>{},
Number<kfold * M0 / mpair>{},
Number<mpair>{},
AK1Number));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
#endif
static_assert(ABlockTransferSrcScalarPerVector % MXdlPerWave == 0);
return make_naive_tensor_descriptor(
......@@ -867,89 +711,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else // RowMajor B
{
#if 0
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
? 1
: ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
? N0
: 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * N1>{},
Number<kfold * N0 / npair>{},
Number<npair>{},
BK1Number));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
#endif
static_assert(BBlockTransferSrcScalarPerVector % NXdlPerWave == 0);
return make_naive_tensor_descriptor(
......@@ -958,19 +719,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
__device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return BlockwiseGemmPipe::GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4();
}
else
{
return BlockwiseGemmPipe::GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
}
}
using BlockwiseGemmPipe =
......@@ -1413,88 +1171,188 @@ struct GridwiseGemm_xdl_cshuffle_v3
num_k_block_main_loop);
// Epilogue
constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
constexpr auto c_thread_desc_mblock_nblock = GetCThreadDescriptor_MBlock_NBlock();
auto c_block_trait = [&]() {
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
constexpr auto c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4();
constexpr auto M0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<0>{});
constexpr auto N0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<1>{});
constexpr auto M1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<3>{});
constexpr auto N1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<4>{});
constexpr auto M3 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M4, M3)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{},
Sequence<2, 4, 5, 9, 7>{},
Sequence<1>{},
Sequence<3, 6, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexMNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, I1, I1, M1, I1, I1, M3, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 8, 9, 7>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4);
}
else
{
constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{},
Sequence<2, 3, 5, 6, 9>{},
Sequence<1>{},
Sequence<4, 7, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4);
}
}();
auto c_thread_copy_vgpr_to_global = c_block_trait.At(Number<0>{});
auto c_grid_desc_mblock_nblock = c_block_trait.At(Number<1>{});
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
c_grid_desc_mblock_nblock,
c_grid_buf);
}
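The immediately-invoked c_block_trait lambda is what lets the two if constexpr branches return differently typed (copier, descriptor) pairs while the surrounding Run stays uniform. A compact C++17 sketch of the idiom (placeholder types, hypothetical names):

#include <tuple>
#include <type_traits>

template <bool ColumnMajorA>
auto make_epilogue_trait()
{
    // each branch may return a different tuple type; auto deduction picks
    // the right one per instantiation
    auto trait = [&]() {
        if constexpr(ColumnMajorA)
            return std::make_tuple(1, 2.0); // (copier, descriptor) stand-ins
        else
            return std::make_tuple(1u, 2.0f);
    }();
    return trait;
}

static_assert(std::is_same<decltype(make_epilogue_trait<true>()),
                           std::tuple<int, double>>::value, "");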
......@@ -1607,11 +1465,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockTransferSrcScalarPerVector,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
......@@ -1638,11 +1496,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockTransferSrcScalarPerVector,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
......@@ -1706,88 +1564,188 @@ struct GridwiseGemm_xdl_cshuffle_v3
num_k_block_main_loop);
// Epilogue
constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
constexpr auto c_thread_desc_mblock_nblock = GetCThreadDescriptor_MBlock_NBlock();
auto c_block_trait = [&]() {
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
constexpr auto c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4();
constexpr auto M0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<0>{});
constexpr auto N0 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<1>{});
constexpr auto M1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<3>{});
constexpr auto N1 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<4>{});
constexpr auto M3 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_n0_m1_m2_n1_m3_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M4, M3)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{},
Sequence<2, 4, 5, 9, 7>{},
Sequence<1>{},
Sequence<3, 6, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexMNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, I1, I1, M1, I1, I1, M3, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 8, 9, 7>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_n0_m1_m2_n1_m3_n2_m4);
}
else
{
constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
transform_tensor_descriptor(
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(problem.MBlock),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_pass_through_transform(problem.NBlock),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{},
Sequence<2, 3, 5, 6, 9>{},
Sequence<1>{},
Sequence<4, 7, 8>{}));
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexNContiguous(
I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Typecast -> Permute -> Coalesced vector store
auto c_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r4<
AccDataType,
CDataType,
decltype(c_thread_desc_mblock_nblock),
decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
CElementwiseOperation,
Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 9, 8>,
9,
8,
M4,
N2,
InMemoryDataOperationEnum::Set,
false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
make_multi_index(block_m_id,
block_n_id,
m_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I4]),
c_element_op};
return make_tuple(c_thread_copy_vgpr_to_global,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4);
}
}();
auto c_thread_copy_vgpr_to_global = c_block_trait.At(Number<0>{});
auto c_grid_desc_mblock_nblock = c_block_trait.At(Number<1>{});
c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
c_grid_desc_mblock_nblock,
c_grid_buf);
}
......
......@@ -983,6 +983,40 @@ struct XdlopsGemm
Sequence<5>{}));
}
// TransposeA
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<7>{},
Sequence<6>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2, 3, 5>{},
Sequence<4>{}));
}
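The unmerge in this descriptor splits the per-XDL row dimension M2 into (num_groups_per_blk, num_input_blks, group_size). A host-side sketch of that decomposition, with assumed mfma sizes (illustrative only):

#include <cstdio>

int main()
{
    // assumed shape: 4 groups x 2 input blocks x group size 4, i.e. M2 = 32
    const int num_input_blks = 2, group_size = 4;
    const int r     = 21;                                // a row index within one XDL tile
    const int group = r / (num_input_blks * group_size); // slowest-varying
    const int blk   = (r / group_size) % num_input_blks;
    const int elem  = r % group_size;                    // fastest-varying
    printf("%d %d %d\n", group, blk, elem); // 2 1 1
    return 0;
}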
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
......
......@@ -18,22 +18,6 @@ struct transpose_vectors;
// transpose fp16 2x2
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
{
#if 0
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
const vector_type<half_t, 2> vx0{x0}, vx1{x1};
vector_type<half_t, 2> vy0, vy1;
vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else
constexpr int32_t m0 = 0x05040100;
constexpr int32_t m1 = 0x07060302;
......@@ -43,7 +27,6 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
// index is reversed because of little endianness (least significant byte first)
y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
#endif
}
template <index_t NX, index_t NY>
......@@ -83,6 +66,60 @@ struct transpose_vectors<half_t, NX, NY>
}
};
// transpose bf16 2x2
__device__ void
transpose_bf16_2x2(const bhalf2_t& x0, const bhalf2_t& x1, bhalf2_t& y0, bhalf2_t& y1)
{
constexpr int32_t m0 = 0x05040100;
constexpr int32_t m1 = 0x07060302;
// ex: v_perm_b32(0x11223344, 0x55667788, 0x05010400) -> 0x33774488
//     selector bytes (from LSB) 0x00, 0x04, 0x01, 0x05 pick pool bytes
//     0x88, 0x44, 0x77, 0x33; pool bytes 0-3 come from src1, bytes 4-7 from src0
// index is reversed because of little endianness (least significant byte first)
y0 =
bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
y1 =
bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
}
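The selector constants can be checked on the host: v_perm_b32 forms an 8-byte pool (bytes 0-3 from src1, bytes 4-7 from src0) and each selector byte picks one pool byte per result byte. A minimal emulation sketch (not the GPU builtin):

#include <cstdint>
#include <cstdio>

static uint32_t perm_b32(uint32_t src0, uint32_t src1, uint32_t sel)
{
    const uint64_t pool = (uint64_t(src0) << 32) | src1;
    uint32_t out = 0;
    for(int i = 0; i < 4; ++i)
    {
        const uint32_t s = (sel >> (8 * i)) & 0xff;             // selector for result byte i
        out |= ((uint32_t)(pool >> (8 * s)) & 0xff) << (8 * i); // picked pool byte
    }
    return out;
}

int main()
{
    // m0 = 0x05040100 gathers the low 16-bit lanes of both inputs,
    // giving {src0.lo, src1.lo}: prints 33447788
    printf("%08x\n", perm_b32(0x11223344u, 0x55667788u, 0x05040100u));
    return 0;
}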
template <index_t NX, index_t NY>
struct transpose_vectors<bhalf_t, NX, NY>
{
// we have NY * NX scalars of type S to transpose
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = bhalf_t;
using VX = vector_type<bhalf_t, s_per_x>;
using VY = vector_type<bhalf_t, s_per_y>;
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
StaticallyIndexedArray<VY&, NY>& vy_tuple)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// reference to 2 bhalf2_t data from vx_tuple
const auto& x_s2_0 = vx_tuple[ix].template AsType<bhalf2_t>()[iy / I2];
const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<bhalf2_t>()[iy / I2];
// reference to 2 bhalf2_t data from vy_tuple
auto& y_s2_0 = vy_tuple(iy).template AsType<bhalf2_t>()(ix / I2);
auto& y_s2_1 = vy_tuple(iy + I1).template AsType<bhalf2_t>()(ix / I2);
// transpose
transpose_bf16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
});
});
}
};
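The static_for nest above sweeps the NX x NY element grid in 2x2 tiles, delegating each tile to transpose_bf16_2x2. A plain host-side analogue of the tiling, with assumed sizes (scalar code in place of the packed builtin):

#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int NX = 4, NY = 4;
    uint16_t x[NX][NY], y[NY][NX];
    for(int i = 0; i < NX; ++i)
        for(int j = 0; j < NY; ++j)
            x[i][j] = uint16_t(i * NY + j);

    for(int iy = 0; iy < NY; iy += 2)
        for(int ix = 0; ix < NX; ix += 2)
            for(int a = 0; a < 2; ++a)     // this inner 2x2 block is what
                for(int b = 0; b < 2; ++b) // transpose_bf16_2x2 does in one step
                    y[iy + a][ix + b] = x[ix + b][iy + a];

    printf("%u\n", unsigned(y[1][3])); // element x[3][1] = 13
    return 0;
}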
// transpose int8 4x4
__device__ void transpose_int8_4x4(const int8x4_t& x0,
const int8x4_t& x1,
......