"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "51f56aa32f5e8c48f196fd24146ea364507326ff"
Commit 0dda6f18 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Verify f8f6f4 MFMA Instructions

parent 0ef27d53
......@@ -784,17 +784,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......@@ -806,17 +808,19 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......@@ -828,17 +832,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......@@ -850,17 +856,19 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment