Commit 370c9245 authored by Jing Zhang's avatar Jing Zhang
Browse files

change mfma_info

parent c982e753
......@@ -35,8 +35,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
......@@ -116,15 +114,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
constexpr index_t NumBlks = CXdlopsLayout.GetNumBlks();
constexpr index_t NumXdlops = CXdlopsLayout.GetNumXdlops();
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
}
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
{
///\to-do: hide xdl clayout into xdlops-gemm
constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();
constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
......
......@@ -34,10 +34,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
......@@ -61,10 +61,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
......@@ -88,10 +88,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
......@@ -115,10 +115,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
......@@ -143,7 +143,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
......@@ -170,10 +170,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
......@@ -197,10 +197,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
......@@ -224,10 +224,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
......@@ -251,10 +251,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
......@@ -278,7 +278,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
......@@ -306,10 +306,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
......@@ -338,10 +338,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
......@@ -369,10 +369,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
......@@ -400,10 +400,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
......@@ -431,7 +431,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
......@@ -659,6 +659,8 @@ struct XdlopsGemm
__host__ __device__ static void mfma_info_check()
{
static_assert(mfma_type.group_size * mfma_type.num_groups_per_blk == mfma_type.num_regs_per_blk,
"wrong! num_regs_per_blk");
static_assert(mfma_type.num_threads_per_blk == mfma_type.n_per_blk,
"n_per_blk != num_threads_per_blk");
......@@ -745,8 +747,9 @@ struct XdlopsGemm
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_type.wave_size; }
__device__ static auto GetBlkIdx(const index_t laneId)
__device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_per_blk))),
......@@ -765,7 +768,7 @@ struct XdlopsGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx(laneId);
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
......@@ -783,7 +786,7 @@ struct XdlopsGemm
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx(laneId);
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
......@@ -801,7 +804,7 @@ struct XdlopsGemm
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx(laneId);
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment