Commit 6778c318 authored by Andriy Roshchenko

WIP: Introduce MX MFMA test

parent c4a05057
@@ -126,18 +126,28 @@ function(add_gtest_executable TEST_NAME)
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     foreach(source IN LISTS ARGN)
         if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
             message("removing xdl test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
+    foreach(source IN LISTS ARGN)
+        if(NOT TEST_TARGETS MATCHES "gfx95" AND source MATCHES "mx_")
+            message("removing microscaling test ${source} ")
+            list(REMOVE_ITEM ARGN "${source}")
+        endif()
+    endforeach()
     foreach(source IN LISTS ARGN)
         if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
             message("removing wmma test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     #only continue if there are some source files left on the list
     if(ARGN)
         if(ARGN MATCHES "_xdl")
@@ -209,5 +219,8 @@ endif()
 if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2
     add_subdirectory(smfmac_op)
 endif()
+if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
+    add_subdirectory(mx_mfma_op)
+endif()
 add_subdirectory(position_embedding)
 add_subdirectory(scatter_gather)
mx_mfma_op/CMakeLists.txt (new file):

add_custom_target(test_mx_mfma)
add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp)
# add_gtest_executable reports via `result` whether the target was created (0 = success)
if(result EQUAL 0)
    target_link_libraries(test_mx_mfma_op PRIVATE utility)
endif()
add_dependencies(test_mx_mfma test_mx_mfma_op)
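A minimal usage sketch, assuming a configured Composable Kernel build tree (the binary location depends on the build configuration; `bin/` is typical):

    cmake --build . --target test_mx_mfma_op
    ./bin/test_mx_mfma_op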
mx_mfma_op/mx_mfma_op.cpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "mx_mfma_op.hpp"
using ck::e8m0_bexp_t;
using ck::f8_ocp_t;
using ck::type_convert;
template <typename Src1Type,
ck::index_t Src1VecSize,
typename Src2Type,
ck::index_t Src2VecSize,
typename DstType,
ck::index_t AccVecSize,
typename AccType,
typename CPUAccType,
ck::index_t M,
ck::index_t N,
ck::index_t K>
bool run_test()
{
using Row = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
bool pass = true;
const auto mx_mfma_kernel = ck::mx_mfma_test::
matmul<Src1Type, Src1VecSize, Src2Type, Src2VecSize, AccType, AccVecSize, DstType, M, N, K>;
pass = ck::mx_mfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
Src1Type,
Src2Type,
DstType,
AccType,
CPUAccType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
AccVecSize,
M,
N,
K>{}(mx_mfma_kernel);
return pass;
}
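// WIP: this first case exercises the plain float path end-to-end; the f8/bf8
// cases below remain disabled until the MX MFMA kernel is implemented (the
// f8/bf8 aliases they use are not defined in this file yet).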
TEST(MXMFMA, FP8MFMA16x16x128)
{
auto pass = run_test<float, 1, float, 1, float, 1, float, float, 16, 16, 128>();
EXPECT_TRUE(pass);
}
// TEST(MXMFMA, FP8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<f8, 1, f8, 1, float, 1, float, float, 32, 32, 64>());
// }
// TEST(MXMFMA, BF8MFMA16x16x128)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 16, 16, 128>());
// }
// TEST(MXMFMA, BF8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
// }
TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
mx_mfma_op/mx_mfma_op.hpp (new file):

#pragma once
#include "ck/ck.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
namespace ck {
namespace mx_mfma_test {
// Placeholder selector for the MX MFMA builtins; presumably to be specialized
// per instruction variant, mirroring the smfmac test (body intentionally empty, WIP).
template <typename src_vec1, typename src_vec2, typename acc_vec>
__device__ void builtin_mx_mfma_naive_selector(const src_vec1&, const src_vec2&, acc_vec&)
{
}
// Smfmac instructions use 4:2 structural sparsity: in every contiguous
// subgroup of 4 elements, at least 2 must be equal to zero, and the positions
// of the non-zero elements are stored in the idx register to allow selection
// of the corresponding B matrix elements for multiplication.
// Currently smfmac instructions support only a sparse A matrix.
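// Worked example (hypothetical data): for the 4-element subgroup {0, 3, 0, 7}
// the non-zero entries sit at positions 1 and 3, so the compressed A fragment
// stores {3, 7} and the corresponding 2-bit index fields hold 0b01 and 0b11.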
template <typename src1_t,
index_t src1_vec_size,
typename src2_t,
index_t src2_vec_size,
typename acc_t,
index_t acc_vec_size,
typename dst_t,
int32_t M,
int32_t N,
int32_t K>
__global__ void matmul(const src1_t* a, const src2_t* b, dst_t* c)
{
__shared__ src1_t a_shared[M * K];
__shared__ src2_t b_shared[K * N];
const int lane = threadIdx.x;
// smfmac's A operand stores only the non-zero elements, in 2 VGPRs;
// smfmac's B operand stores all elements, in 4 VGPRs
using src1_vec      = typename vector_type<src1_t, src1_vec_size>::type;
using src1_full_vec = typename vector_type<src1_t, src1_vec_size * 2>::type;
using src2_vec      = typename vector_type<src2_t, src2_vec_size>::type;
src1_vec a_frag = {};
src2_vec b_frag = {};
// NB: the hard-coded 8-element loops below assume src1_vec_size * 2 == 8 and
// src2_vec_size == 8; other instantiations would overrun these buffers.
src1_full_vec a_temp = {};
src2_vec b_temp = {};
// initialize c fragment to 0
using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_vec_size, true>;
acc_vec c_thread_buf_;
for(int i = 0; i < 8; ++i)
{
a_temp[i] = a[(lane % M) * K + (lane / M) * 8 + i]; // M K
}
for(int i = 0; i < 8; ++i)
{
b_temp[i] = b[(8 * (lane / N) + i) * N + (lane % N)]; // K N
}
__syncthreads();
for(int i = 0; i < 8; ++i)
{
a_shared[(lane % M) * K + (lane / M) * 8 + i] = a_temp[i];
}
for(int i = 0; i < 8; ++i)
{
b_shared[(8 * (lane / N) + i) * N + (lane % N)] = b_temp[i];
}
__syncthreads();
// idx must be a 32-bit register; it stores four 2-bit indices of A's non-zero
// elements. It starts with the last two elements of every 4-element subgroup
// marked as non-zero.
int32_t idx = 0b11101110;
// Bit masks for clearing the 0th-3rd 2-bit index field of idx
static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
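// e.g. the initial value 0b11'10'11'10 decodes, field by field (low bits
// first), to positions {2, 3} for each of the two 4-element subgroups.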
src1_t curr_val;
int32_t a_pos = 0;
for(int j = 0; j < 2; ++j)
{
a_pos = j * 2;
for(int i = 0; i < 4; ++i)
{
curr_val = a_shared[(lane % M) * K + (lane / M) * 8 + 4 * j + i];
if(curr_val != 0.0f)
{
idx &= ~bit_clear_masks[a_pos];
idx |= (i % 4) << 2 * a_pos;
a_frag[a_pos] = curr_val;
a_pos++;
}
}
}
for(int i = 0; i < 8; ++i)
{
b_frag[i] = b_shared[(8 * (lane / N) + i) * N + (lane % N)];
}
// The idx register computed above is a leftover from the smfmac variant; the
// placeholder MX selector declared at the top of this file only consumes the
// two source fragments and the accumulator.
builtin_mx_mfma_naive_selector<src1_vec, src2_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
__syncthreads();
// store results from unpacked c_thread_buf_ output
if constexpr(K == 32)
{
static_for<0, acc_vec_size, 1>{}([&](auto i) {
c[(4 * (lane / 16) + i) * N + lane % 16] =
ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
});
}
else
{
static_for<0, acc_vec_size, 1>{}([&](auto i) {
c[((8 * (i / 4)) % 32 + 4 * (lane / 32) + i % 4) * N + lane % 32] =
ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
});
}
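// Worked example for the K == 32 branch above: lane 17 with acc_vec_size == 4
// writes rows 4..7 of column 1, since 4 * (17 / 16) = 4 and 17 % 16 = 1.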
}
/**
* @brief Structure to hold dimension parameters for GEMM tensors.
*
* M Number of rows in matrix A and matrix C.
* N Number of columns in matrix B and matrix C.
* K Number of columns in matrix A and number of rows in matrix B.
* StrideA Stride (leading dimension) of matrix A.
* StrideB Stride (leading dimension) of matrix B.
* StrideC Stride (leading dimension) of matrix C.
*/
struct GemmParams
{
/**
* @brief This constructor initializes the parameters for GEMM storage with default values.
*
* A[16x128] * B[128x16] = C[16x16], all row major.
*/
GemmParams() : M(16), N(16), K(128), StrideA(128), StrideB(16), StrideC(16) {}
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t StrideA;
ck::index_t StrideB;
ck::index_t StrideC;
};
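// Example (hypothetical override) for a 32x32x64 row-major problem:
//   GemmParams params;
//   params.M = 32; params.N = 32; params.K = 64;
//   params.StrideA = 64; params.StrideB = 32; params.StrideC = 32;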
template <typename GemmInstance,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(A.mData.data());
b_n_k_device_buf.ToDevice(B.mData.data());
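// Launch a single 64-lane workgroup: this naive kernel maps one wavefront
// onto one MFMA tile.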
kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
c_m_n_device_buf.FromDevice(C.mData.data());
return true;
}
template <typename DeviceMXMFMA,
typename ADataType,
typename BDataType,
typename CDataType,
typename GPUAccDataType,
typename CPUAccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t CAccNum,
index_t M,
index_t N,
index_t K>
struct TestMXMFMA
{
auto PrepareGemmTensors(const GemmParams& params)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
}
auto operator()(const DeviceMXMFMA& mfma_kernel)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
// Arrange
GemmParams params;
params.M = M;
params.N = N;
params.K = K;
params.StrideA = K; // M K
params.StrideB = N; // K N
params.StrideC = N; // M N
auto host_tensors = PrepareGemmTensors(params);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{};
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
CPUAccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);
RunDeviceGEMM(mfma_kernel, a, b, c_device);
bool res = false;
if constexpr(std::is_same<CDataType, float>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else
{
std::cout << "UNSUPPORTED CDataType" << std::endl;
}
return res;
}
};
} // namespace mx_mfma_test
} // namespace ck