Commit 370c9245 authored by Jing Zhang's avatar Jing Zhang
Browse files

change mfma_info

parent c982e753
...@@ -35,8 +35,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -35,8 +35,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
...@@ -116,15 +114,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -116,15 +114,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!"); "wrong!");
constexpr index_t NumBlks = CXdlopsLayout.GetNumBlks();
constexpr index_t NumXdlops = CXdlopsLayout.GetNumXdlops();
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
} }
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
{ {
///\to-do: hide xdl clayout into xdlops-gemm
constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();
constexpr auto M0 = Number<CXdlopsLayout.M1()>{}; constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
constexpr auto M2 = Number<CXdlopsLayout.M0()>{}; constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
......
...@@ -34,10 +34,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32> ...@@ -34,10 +34,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32; static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32; static constexpr index_t n_per_blk = 32;
...@@ -61,10 +61,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32> ...@@ -61,10 +61,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32; static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32; static constexpr index_t n_per_blk = 32;
...@@ -88,10 +88,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32> ...@@ -88,10 +88,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16; static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16; static constexpr index_t n_per_blk = 16;
...@@ -115,10 +115,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32> ...@@ -115,10 +115,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16; static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16; static constexpr index_t n_per_blk = 16;
...@@ -143,7 +143,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32> ...@@ -143,7 +143,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
...@@ -170,10 +170,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16> ...@@ -170,10 +170,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32; static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32; static constexpr index_t n_per_blk = 32;
...@@ -197,10 +197,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16> ...@@ -197,10 +197,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32; static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32; static constexpr index_t n_per_blk = 32;
...@@ -224,10 +224,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16> ...@@ -224,10 +224,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16; static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16; static constexpr index_t n_per_blk = 16;
...@@ -251,10 +251,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16> ...@@ -251,10 +251,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16; static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16; static constexpr index_t n_per_blk = 16;
...@@ -278,7 +278,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16> ...@@ -278,7 +278,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
...@@ -306,10 +306,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16> ...@@ -306,10 +306,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32; static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32; static constexpr index_t n_per_blk = 32;
...@@ -338,10 +338,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16> ...@@ -338,10 +338,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32; static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32; static constexpr index_t n_per_blk = 32;
...@@ -369,10 +369,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16> ...@@ -369,10 +369,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16; static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16; static constexpr index_t n_per_blk = 16;
...@@ -400,10 +400,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16> ...@@ -400,10 +400,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16; static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16; static constexpr index_t n_per_blk = 16;
...@@ -431,7 +431,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16> ...@@ -431,7 +431,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
...@@ -659,6 +659,8 @@ struct XdlopsGemm ...@@ -659,6 +659,8 @@ struct XdlopsGemm
__host__ __device__ static void mfma_info_check() __host__ __device__ static void mfma_info_check()
{ {
static_assert(mfma_type.group_size * mfma_type.num_groups_per_blk == mfma_type.num_regs_per_blk,
"wrong! num_regs_per_blk");
static_assert(mfma_type.num_threads_per_blk == mfma_type.n_per_blk, static_assert(mfma_type.num_threads_per_blk == mfma_type.n_per_blk,
"n_per_blk != num_threads_per_blk"); "n_per_blk != num_threads_per_blk");
...@@ -745,8 +747,9 @@ struct XdlopsGemm ...@@ -745,8 +747,9 @@ struct XdlopsGemm
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_type.wave_size; } __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_type.wave_size; }
__device__ static auto GetBlkIdx(const index_t laneId) __device__ static auto GetBlkIdx()
{ {
const auto laneId = GetLaneId();
const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform( make_tuple(make_merge_transform(
make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_per_blk))), make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_per_blk))),
...@@ -765,7 +768,7 @@ struct XdlopsGemm ...@@ -765,7 +768,7 @@ struct XdlopsGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex() __host__ __device__ static auto CalculateAThreadOriginDataIndex()
{ {
const auto laneId = GetLaneId(); const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx(laneId); const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0]; const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1]; const auto blk_td = blk_idx[I1];
...@@ -783,7 +786,7 @@ struct XdlopsGemm ...@@ -783,7 +786,7 @@ struct XdlopsGemm
__host__ __device__ static auto CalculateBThreadOriginDataIndex() __host__ __device__ static auto CalculateBThreadOriginDataIndex()
{ {
const auto laneId = GetLaneId(); const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx(laneId); const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0]; const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1]; const auto blk_td = blk_idx[I1];
...@@ -801,7 +804,7 @@ struct XdlopsGemm ...@@ -801,7 +804,7 @@ struct XdlopsGemm
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{ {
const auto laneId = GetLaneId(); const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx(laneId); const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0]; const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1]; const auto blk_td = blk_idx[I1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment