Commit 0dda6f18 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Verify f8f6f4 MFMA Instructions

parent 0ef27d53
...@@ -784,17 +784,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8> ...@@ -784,17 +784,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4> struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
{ {
static constexpr index_t group_size = 4; // clang-format off
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t wave_size = 64; static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t num_input_blks = 2; static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_output_blks = 1; static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t m_per_blk = 32; static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t n_per_blk = 32; static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 8; static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr bool is_k_reduction = true; static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC> template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
...@@ -806,17 +808,19 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4> ...@@ -806,17 +808,19 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4> struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
{ {
static constexpr index_t group_size = 4; // clang-format off
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t wave_size = 64; static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t num_input_blks = 4; static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_output_blks = 1; static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t m_per_blk = 16; static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t n_per_blk = 16; static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 8; static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr bool is_k_reduction = true; static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC> template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
...@@ -828,17 +832,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4> ...@@ -828,17 +832,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
template <> template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4> struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
{ {
static constexpr index_t group_size = 4; // clang-format off
static constexpr index_t num_groups_per_blk = 4; static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_threads_per_blk = 32; static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t wave_size = 64; static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t num_input_blks = 2; static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_output_blks = 1; static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t m_per_blk = 32; static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t n_per_blk = 32; static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 8; static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr bool is_k_reduction = true; static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC> template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
...@@ -850,17 +856,19 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4> ...@@ -850,17 +856,19 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
template <> template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4> struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
{ {
static constexpr index_t group_size = 4; // clang-format off
static constexpr index_t num_groups_per_blk = 1; static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_threads_per_blk = 16; static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t wave_size = 64; static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t num_input_blks = 4; static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_output_blks = 1; static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t m_per_blk = 16; static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t n_per_blk = 16; static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 8; static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr bool is_k_reduction = true; static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC> template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment