Commit 3973caa4 authored by illsilin's avatar illsilin
Browse files

switch between intrinsic mfma routines on mi100/200 and mi300

parent dc58fa9a
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
namespace ck { namespace ck {
#if (defined(__gfx908__) || defined(__gfx90a__))
enum struct MfmaInstr enum struct MfmaInstr
{ {
mfma_f32_32x32x1xf32 = 0, mfma_f32_32x32x1xf32 = 0,
...@@ -29,6 +30,28 @@ enum struct MfmaInstr ...@@ -29,6 +30,28 @@ enum struct MfmaInstr
mfma_i32_16x16x16i8, mfma_i32_16x16x16i8,
mfma_f64_16x16x4f64 mfma_f64_16x16x4f64
}; };
#elif (defined(__gfx940__))
enum struct MfmaInstr
{
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32,
mfma_f32_16x16x4xf32,
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16,
mfma_f32_16x16x16f16,
mfma_f32_32x32x8bf16_1k,
mfma_f32_16x16x16bf16_1k,
mfma_f32_32x32x4bf16,
mfma_f32_16x16x8bf16,
mfma_i32_32x32x16i8,
mfma_i32_16x16x16i8,
mfma_f64_16x16x4f64
};
#endif
template <MfmaInstr instr> template <MfmaInstr instr>
struct mfma_type; struct mfma_type;
...@@ -342,6 +365,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16> ...@@ -342,6 +365,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
} }
}; };
#if (defined(__gfx908__) || defined(__gfx90a__))
template <> template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8> struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
{ {
...@@ -363,6 +387,29 @@ struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8> ...@@ -363,6 +387,29 @@ struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#elif (defined(__gfx940__))
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 4;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <> template <>
struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8> struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
...@@ -524,11 +571,19 @@ struct MfmaSelector ...@@ -524,11 +571,19 @@ struct MfmaSelector
#endif #endif
} }
#if (defined(__gfx908__) || defined(__gfx90a__))
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() static constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x8i8; return MfmaInstr::mfma_i32_32x32x8i8;
} }
#elif (defined(__gfx940__))
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x16i8;
}
#endif
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() static constexpr auto GetMfma<int8_t, 16, 16>()
......
...@@ -259,6 +259,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> ...@@ -259,6 +259,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
} }
}; };
#if (defined(__gfx908__) || defined(__gfx90a__))
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8; struct intrin_mfma_i32_32x32x8i8;
...@@ -277,6 +278,26 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> ...@@ -277,6 +278,26 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
0); 0);
} }
}; };
#elif (defined(__gfx940__))
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;
template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
template <class FloatC>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int32_t>(reg_a),
bit_cast<int32_t>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}],
0,
0,
0);
}
};
#endif
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8; struct intrin_mfma_i32_16x16x16i8;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment