Commit 5dbaf3c2 authored by Jing Zhang's avatar Jing Zhang
Browse files

refactor xdlops, hide c desc

parent 370c9245
......@@ -118,13 +118,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
{
///\to-do: hide xdl clayout into xdlops-gemm
constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, I1, M2, I1));
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
......@@ -195,7 +196,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k0) {
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
// read A
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(k0, I0, I0, I0, I0),
......@@ -212,8 +213,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(I0, I0, I0, I0, I0),
b_thread_buf);
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_per_blk>::type;
using mfma_input_type = typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
......
......@@ -7,7 +7,7 @@
namespace ck {
enum struct mfma_instr
enum struct MfmaInstr
{
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
......@@ -26,11 +26,11 @@ enum struct mfma_instr
mfma_f32_16x16x8bf16, // k reduction
};
template <mfma_instr instr>
struct mfma_info;
template <MfmaInstr instr>
struct mfma_type;
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
......@@ -57,7 +57,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
......@@ -84,7 +84,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -111,7 +111,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -139,7 +139,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -166,7 +166,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
......@@ -193,7 +193,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
......@@ -220,7 +220,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -247,7 +247,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -274,7 +274,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -302,7 +302,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
#if 0
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
......@@ -334,7 +334,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
......@@ -365,7 +365,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -396,7 +396,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -427,7 +427,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
......@@ -458,229 +458,229 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
};
#endif
template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_>
struct xdlops_info
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
{
static constexpr auto mfma_type = mfma_info<instr>{};
static constexpr index_t MPerXdlops = MPerXdlops_;
static constexpr index_t NPerXdlops = NPerXdlops_;
static constexpr bool IsABroadcast()
{
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
return true;
}
static constexpr index_t GetKPerXdlops()
{
return mfma_type.is_k_reduction ? mfma_type.num_input_blks : 1;
}
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
};
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
struct XdlopsGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
template <class base_type_ = base_type,
index_t MPerWave_ = MPerWave,
index_t NPerWave_ = NPerWave>
static constexpr auto GetXdlopsInfo();
template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
static constexpr auto GetMfma();
template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>()
static constexpr auto GetMfma<float, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{};
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>()
static constexpr auto GetMfma<float, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{};
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>()
static constexpr auto GetMfma<float, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{};
return MfmaInstr::mfma_f32_16x16x1xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>()
static constexpr auto GetMfma<float, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{};
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>()
static constexpr auto GetMfma<float, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{};
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>()
static constexpr auto GetMfma<float, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{};
return MfmaInstr::mfma_f32_32x32x2xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>()
static constexpr auto GetMfma<float, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
return MfmaInstr::mfma_f32_16x16x4xf32;
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
static constexpr auto GetMfma<half_t, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x4f16, 64, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
static constexpr auto GetMfma<half_t, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x4f16, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
static constexpr auto GetMfma<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x8f16, 32, 32>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
static constexpr auto GetMfma<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x16f16, 16, 16>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
static constexpr auto GetMfma<half_t, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x4f16, 16, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
static constexpr auto GetMfma<half_t, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_4x4x4f16, 8, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
static constexpr auto GetMfma<half_t, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_4x4x4f16, 4, 64>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
static constexpr auto GetMfma<ushort, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
static constexpr auto GetMfma<ushort, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
static constexpr auto GetMfma<ushort, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
static constexpr auto GetMfma<ushort, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
static constexpr auto GetMfma<ushort, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
static constexpr auto GetMfma<ushort, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
static constexpr auto GetMfma<ushort, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
static constexpr auto GetMfma<ushort, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
static constexpr auto GetMfma<ushort, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
static constexpr auto GetMfma<ushort, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
static constexpr auto GetMfma<ushort, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
}
#endif
using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(mfma_type.m_per_blk * mfma_type.n_per_blk * mfma_type.num_output_blks);
}
__host__ __device__ static void mfma_info_check()
__host__ __device__ static constexpr void mfma_check()
{
static_assert(mfma_type.group_size * mfma_type.num_groups_per_blk == mfma_type.num_regs_per_blk,
static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
selected_mfma.num_regs_per_blk,
"wrong! num_regs_per_blk");
static_assert(mfma_type.num_threads_per_blk == mfma_type.n_per_blk,
static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
"n_per_blk != num_threads_per_blk");
static_assert(mfma_type.num_regs_per_blk * mfma_type.num_input_blks == mfma_type.m_per_blk,
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
selected_mfma.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
selected_mfma.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_per_blk * mfma_type.wave_size ==
mfma_type.m_per_blk * mfma_type.n_per_blk,
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
selected_mfma.m_per_blk * selected_mfma.n_per_blk,
"num_regs_per_blk incorrect");
static_assert(mfma_type.is_k_reduction ||
(mfma_type.num_input_blks == mfma_type.num_output_blks),
static_assert(selected_mfma.is_k_reduction ||
(selected_mfma.num_input_blks == selected_mfma.num_output_blks),
"is_k_reduction wrong!");
}
__host__ __device__ constexpr MfmaSelector() { mfma_check(); }
static constexpr bool IsABroadcast()
{
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
return true;
}
static constexpr index_t GetKPerXdlops()
{
return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
selected_mfma.k_per_blk;
}
static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; }
};
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
struct XdlopsGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
}
__host__ __device__ constexpr XdlopsGemm()
{
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
......@@ -690,6 +690,8 @@ struct XdlopsGemm
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
}
template <typename CM0N0M1N1M2N2Desc>
......@@ -707,10 +709,10 @@ struct XdlopsGemm
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(mfma_type.num_groups_per_blk,
mfma_type.num_input_blks,
mfma_type.group_size)),
make_pass_through_transform(mfma_type.num_threads_per_blk)),
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -727,7 +729,7 @@ struct XdlopsGemm
__device__ static constexpr index_t GetRegSizePerXdlops()
{
return MPerXdlops * NPerXdlops / mfma_type.wave_size;
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
}
template <index_t c_offset, class FloatA, class FloatB, class FloatC>
......@@ -737,22 +739,20 @@ struct XdlopsGemm
is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!");
static_assert(KPack % mfma_type.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
static_for<0, KPack / mfma_type.k_per_blk, 1>{}([&](auto k) {
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
mfma_instr.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[k], p_b_wave[k], p_c_thread);
});
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_type.wave_size; }
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
__device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_per_blk))),
make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
......@@ -773,7 +773,7 @@ struct XdlopsGemm
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(mfma_type.is_k_reduction)
if constexpr(mfma_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
......@@ -791,7 +791,7 @@ struct XdlopsGemm
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(mfma_type.is_k_reduction)
if constexpr(mfma_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
......@@ -803,45 +803,29 @@ struct XdlopsGemm
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
index_t n_offset = blk_i * mfma_type.n_per_blk + blk_td;
index_t m_offset = xdlops_i * mfma_type.m_per_blk + blk_id * mfma_type.group_size;
index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
return CIndex{m_offset, n_offset};
}
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr auto mfma_instr = mfma.selected_mfma;
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto KPerThread = mfma.GetKPerThread();
struct CLayout
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
{
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_per_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_per_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_per_blk; }
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(mfma_type.m_per_blk * mfma_type.n_per_blk * mfma_type.num_output_blks);
return make_tuple(
Number<mfma_instr.num_groups_per_blk>{}, I1, Number<mfma_instr.group_size>{}, I1);
}
};
__host__ __device__ static constexpr auto GetCXdlopsLayout() { return CLayout{}; }
};
} // namespace ck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment