Unverified commit c38163cd, authored by Andriy Roshchenko, committed by GitHub

Test the functionality of the V_MFMA_F32_16X16X128_F8F6F4 and V_MFMA_F32_32X32X64_F8F6F4 instructions. (#293)

* Introduced MFMA tests

* Verified f8f6f4 MFMA instructions
parent bcef33c1
@@ -784,17 +784,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
 {
-    static constexpr index_t group_size          = 4;
-    static constexpr index_t num_groups_per_blk  = 4;
-    static constexpr index_t num_regs_per_blk    = 16;
-    static constexpr index_t num_threads_per_blk = 32;
-    static constexpr index_t wave_size           = 64;
-    static constexpr index_t num_input_blks      = 2;
-    static constexpr index_t num_output_blks     = 1;
-    static constexpr index_t m_per_blk           = 32;
-    static constexpr index_t n_per_blk           = 32;
-    static constexpr index_t k_per_blk           = 8;
-    static constexpr bool is_k_reduction         = true;
+    // clang-format off
+    static constexpr index_t group_size          = 4;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16; // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32; // n_per_blk
+    static constexpr index_t wave_size           = 64; // fixed
+    static constexpr index_t num_input_blks      = 2;  // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;  // 1 when is_k_reduction == true
+    static constexpr index_t m_per_blk           = 32; // from the instruction
+    static constexpr index_t n_per_blk           = 32; // from the instruction
+    static constexpr index_t k_per_blk           = 32; // 64 / num_input_blks when is_k_reduction == true
+    static constexpr bool is_k_reduction         = true;
+    // clang-format on

     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
@@ -806,17 +808,19 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
 {
-    static constexpr index_t group_size          = 4;
-    static constexpr index_t num_groups_per_blk  = 1;
-    static constexpr index_t num_regs_per_blk    = 4;
-    static constexpr index_t num_threads_per_blk = 16;
-    static constexpr index_t wave_size           = 64;
-    static constexpr index_t num_input_blks      = 4;
-    static constexpr index_t num_output_blks     = 1;
-    static constexpr index_t m_per_blk           = 16;
-    static constexpr index_t n_per_blk           = 16;
-    static constexpr index_t k_per_blk           = 8;
-    static constexpr bool is_k_reduction         = true;
+    // clang-format off
+    static constexpr index_t group_size          = 4;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;  // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16; // n_per_blk
+    static constexpr index_t wave_size           = 64; // fixed
+    static constexpr index_t num_input_blks      = 4;  // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;  // 1 when is_k_reduction == true
+    static constexpr index_t m_per_blk           = 16; // from the instruction
+    static constexpr index_t n_per_blk           = 16; // from the instruction
+    static constexpr index_t k_per_blk           = 32; // 128 / num_input_blks when is_k_reduction == true
+    static constexpr bool is_k_reduction         = true;
+    // clang-format on

     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
@@ -828,17 +832,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
 template <>
 struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
 {
-    static constexpr index_t group_size          = 4;
-    static constexpr index_t num_groups_per_blk  = 4;
-    static constexpr index_t num_regs_per_blk    = 16;
-    static constexpr index_t num_threads_per_blk = 32;
-    static constexpr index_t wave_size           = 64;
-    static constexpr index_t num_input_blks      = 2;
-    static constexpr index_t num_output_blks     = 1;
-    static constexpr index_t m_per_blk           = 32;
-    static constexpr index_t n_per_blk           = 32;
-    static constexpr index_t k_per_blk           = 8;
-    static constexpr bool is_k_reduction         = true;
+    // clang-format off
+    static constexpr index_t group_size          = 4;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16; // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32; // n_per_blk
+    static constexpr index_t wave_size           = 64; // fixed
+    static constexpr index_t num_input_blks      = 2;  // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;  // 1 when is_k_reduction == true
+    static constexpr index_t m_per_blk           = 32; // from the instruction
+    static constexpr index_t n_per_blk           = 32; // from the instruction
+    static constexpr index_t k_per_blk           = 32; // 64 / num_input_blks when is_k_reduction == true
+    static constexpr bool is_k_reduction         = true;
+    // clang-format on

     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
@@ -850,17 +856,19 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
 template <>
 struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
 {
-    static constexpr index_t group_size          = 4;
-    static constexpr index_t num_groups_per_blk  = 1;
-    static constexpr index_t num_regs_per_blk    = 4;
-    static constexpr index_t num_threads_per_blk = 16;
-    static constexpr index_t wave_size           = 64;
-    static constexpr index_t num_input_blks      = 4;
-    static constexpr index_t num_output_blks     = 1;
-    static constexpr index_t m_per_blk           = 16;
-    static constexpr index_t n_per_blk           = 16;
-    static constexpr index_t k_per_blk           = 8;
-    static constexpr bool is_k_reduction         = true;
+    // clang-format off
+    static constexpr index_t group_size          = 4;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;  // group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;  // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16; // n_per_blk
+    static constexpr index_t wave_size           = 64; // fixed
+    static constexpr index_t num_input_blks      = 4;  // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;  // 1 when is_k_reduction == true
+    static constexpr index_t m_per_blk           = 16; // from the instruction
+    static constexpr index_t n_per_blk           = 16; // from the instruction
+    static constexpr index_t k_per_blk           = 32; // 128 / num_input_blks when is_k_reduction == true
+    static constexpr bool is_k_reduction         = true;
+    // clang-format on

     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
...
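The relations recorded in the new comments above can be checked mechanically. Below is a minimal host-compilable sketch (a hypothetical helper, not part of this patch) that encodes the stated invariants and verifies them for both tile shapes:

// check_mfma_layout.cpp -- standalone sanity check of the mfma_type invariants (illustration only)
template <int M, int N, int K, int group_size, int num_groups_per_blk, int num_regs_per_blk,
          int num_threads_per_blk, int wave_size, int num_input_blks, int k_per_blk>
struct check_mfma_layout
{
    static_assert(group_size * num_groups_per_blk == num_regs_per_blk, "register grouping");
    static_assert(num_regs_per_blk == M * N / wave_size, "accumulator regs per lane");
    static_assert(num_threads_per_blk == N, "lanes per block span n_per_blk");
    static_assert(num_input_blks == M / num_regs_per_blk, "input blocks per wave");
    static_assert(k_per_blk == K / num_input_blks, "K is split across input blocks (is_k_reduction)");
    static constexpr bool ok = true;
};

// Values taken from the specializations above:
static_assert(check_mfma_layout<32, 32, 64, 4, 4, 16, 32, 64, 2, 32>::ok, "32x32x64");
static_assert(check_mfma_layout<16, 16, 128, 4, 1, 4, 16, 64, 4, 32>::ok, "16x16x128");

int main() { return 0; }

In particular, k_per_blk = K / num_input_blks evaluates to 32 for both instructions, which is exactly the correction this commit makes (the old value was 8).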
@@ -476,24 +476,33 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     }
 };

-// TODO: fix ...f8f6f4 instructions
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x64f8f6f4;

+/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8,
+/// f6, and f4 data types.
+///
+/// @note Calls the scaled version of the instruction, as the unscaled instruction is not supported
+/// in the backend. That is the intended use: a backend optimization selects the unscaled
+/// operation when the scale is 0.
 template <>
 struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
     {
 #if defined(__gfx950__)
-        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x64_f8f6f4(
-            bit_cast<long>(reg_a),
-            bit_cast<long>(reg_b),
-            reg_c.template AsType<float16_t>()[Number<0>{}],
-            0,
-            0,
-            0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0,  // cbsz
+                0,  // blgp
+                0,  // scale_a opsel
+                0,  // scale_a = 0 -> backend selects the unscaled operation
+                0,  // scale_b opsel
+                0); // scale_b = 0 -> backend selects the unscaled operation
 #else
         ignore = reg_a;
         ignore = reg_b;
@@ -509,20 +518,30 @@ template <>
 struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
     {
 #if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
-                bit_cast<long>(reg_a),
-                bit_cast<long>(reg_b),
+                reg_a,
+                reg_b,
                 reg_c.template AsType<float16_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
 #else
         ignore = reg_a;
+        ignore = scale_a;
         ignore = reg_b;
+        ignore = scale_b;
         ignore = reg_c;
 #endif
     }
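For readers unfamiliar with the cbsz/blgp immediates: on gfx950 they select the storage format of the A and B operands respectively, which is how a single *_f8f6f4 intrinsic covers f8, f6, and f4 inputs. The sketch below is an assumption for illustration, following LLVM's gfx950 MFMA support referenced above rather than anything in this patch; both call sites here pass 0, i.e. FP8 (E4M3) inputs:

// Assumed cbsz/blgp encoding for the *_f8f6f4 intrinsics (per LLVM's gfx950 support);
// cbsz selects the format of matrix A, blgp the format of matrix B.
enum class MfmaSrcFormat : int
{
    fp8 = 0, // E4M3 -- what the calls above select
    bf8 = 1, // E5M2
    fp6 = 2, // E2M3
    bf6 = 3, // E3M2
    fp4 = 4  // E2M1
};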
@@ -535,20 +554,30 @@ template <>
 struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
     {
 #if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
         reg_c.template AsType<float4_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
-                bit_cast<long>(reg_a),
-                bit_cast<long>(reg_b),
+                reg_a,
+                reg_b,
                 reg_c.template AsType<float4_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
 #else
         ignore = reg_a;
+        ignore = scale_a;
         ignore = reg_b;
+        ignore = scale_b;
         ignore = reg_c;
 #endif
     }
@@ -557,20 +586,31 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_16x16x128f8f6f4;

+/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8,
+/// f6, and f4 data types.
+///
+/// @note Calls the scaled version of the instruction, as the unscaled instruction is not supported
+/// in the backend. That is the intended use: a backend optimization selects the unscaled
+/// operation when the scale is 0.
 template <>
 struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
     {
 #if defined(__gfx950__)
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x128_f8f6f4(
-            bit_cast<long>(reg_a),
-            bit_cast<long>(reg_b),
-            reg_c.template AsType<float4_t>()[Number<0>{}],
-            0,
-            0,
-            0);
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0,  // cbsz
+                0,  // blgp
+                0,  // scale_a opsel
+                0,  // scale_a = 0 -> backend selects the unscaled operation
+                0,  // scale_b opsel
+                0); // scale_b = 0 -> backend selects the unscaled operation
 #else
         ignore = reg_a;
         ignore = reg_b;
...
@@ -169,18 +169,28 @@ function(add_gtest_executable TEST_NAME)
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()

     foreach(source IN LISTS ARGN)
         if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
             message("removing xdl test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()

+    foreach(source IN LISTS ARGN)
+        if(NOT TEST_TARGETS MATCHES "gfx95" AND source MATCHES "mx_")
+            message("removing microscaling test ${source} ")
+            list(REMOVE_ITEM ARGN "${source}")
+        endif()
+    endforeach()
+
     foreach(source IN LISTS ARGN)
         if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
             message("removing wmma test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()

     # only continue if there are some source files left on the list
     if(ARGN)
         if(ARGN MATCHES "_xdl")
@@ -189,6 +199,8 @@ function(add_gtest_executable TEST_NAME)
             list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
         elseif(ARGN MATCHES "_smfmac")
             list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
+        elseif(ARGN MATCHES "_mx") # only build mx tests for gfx950
+            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
         endif()
     set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
     add_executable(${TEST_NAME} ${ARGN})
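To see the new "mx_" source filter from the first hunk in isolation, here is a small stand-alone sketch with hypothetical values (runnable with cmake -P; ARGN_DEMO stands in for the function's ARGN):

# filter_demo.cmake -- sketch of the microscaling source filter (hypothetical values)
set(TEST_TARGETS "gfx908;gfx90a")              # no gfx95* target present
set(ARGN_DEMO "mx_mfma_op.cpp;other_test.cpp") # stand-in for ARGN

foreach(source IN LISTS ARGN_DEMO)
    if(NOT TEST_TARGETS MATCHES "gfx95" AND source MATCHES "mx_")
        message("removing microscaling test ${source}")
        list(REMOVE_ITEM ARGN_DEMO "${source}")
    endif()
endforeach()

message("sources kept: ${ARGN_DEMO}")          # -> other_test.cpp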
@@ -261,5 +273,8 @@ endif()
 if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2
     add_subdirectory(smfmac_op)
 endif()
+if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
+    add_subdirectory(mx_mfma_op)
+endif()
 add_subdirectory(position_embedding)
 add_subdirectory(scatter_gather)
add_custom_target(test_mx_mfma)
add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp)
if(result EQUAL 0)
target_link_libraries(test_mx_mfma_op PRIVATE utility)
endif()
add_dependencies(test_mx_mfma test_mx_mfma_op)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "mx_mfma_op.hpp"
using ck::e8m0_bexp_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;
/**
* @brief Run the test for the given MFMA instruction
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_test(ck::index_t init)
{
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
using AccType = float; // only MFMA_F32 instructions supported
using CPUAccType = AccType;
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
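// e.g. F32_16x16x128: BLOCK_K = 4 * 32 = 128; F32_32x32x64: BLOCK_K = 2 * 32 = 64,
// i.e. the full K extent named in the instruction mnemonic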
const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
bool pass = true;
pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
AType,
BType,
CType,
AccType,
CPUAccType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mx_mfma_kernel, init);
return pass;
}
TEST(MFMA, FP8MFMA16x16x128)
{
auto AB_init = 0;
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
auto AB_init = 0;
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
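As a cross-check of the operand types used by the intrinsics above (f8x32_t inputs; float4_t and float16_t accumulators), the per-lane fragment sizes follow from the tile shapes and the 64-lane wave. A minimal standalone sketch (illustration only, not part of the patch):

// fragment_sizes.cpp -- per-lane fragment arithmetic for the two instructions (illustration only)
constexpr int wave_size = 64;

// A/B fragments: one f8x32_t (32 f8 values) per lane covers a whole M x K input tile.
static_assert(16 * 128 == wave_size * 32, "16x16x128: M*K == 64 lanes x 32 f8 elements");
static_assert(32 * 64 == wave_size * 32, "32x32x64:  M*K == 64 lanes x 32 f8 elements");

// C fragments: M*N float accumulators spread across the wave.
static_assert(16 * 16 / wave_size == 4, "16x16x128 -> float4_t per lane");
static_assert(32 * 32 / wave_size == 16, "32x32x64 -> float16_t per lane");

int main() { return 0; }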