Commit eb06c923 authored by Andriy Roshchenko

Add tests for MFMA_F8F6F4::F32_16x16x128 and MFMA_F8F6F4::F32_32x32x64 instructions

parent a619e3f5
@@ -530,7 +530,7 @@ endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
-    add_compile_options(-fcolor-diagnostics)
+    # add_compile_options(-fcolor-diagnostics)
 endif()
 if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9)
     add_compile_options(-fdiagnostics-color=always)
{
"version": 3,
"configurePresets": [
{
"name": "linux-debug",
"displayName": "Linux Debug",
"hidden": true,
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build/${presetName}",
"installDir": "${sourceDir}/build/install/${presetName}",
"environment": {
"MY_ENVIRONMENT_VARIABLE": "NONE",
"PATH": "/usr/local/.cargo/bin:$penv{PATH}",
"SCCACHE_IDLE_TIMEOUT": "11000"
},
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"BUILD_DEV": "ON",
"CMAKE_CXX_COMPILER": "/opt/rocm/bin/hipcc",
"CMAKE_PREFIX_PATH": "/opt/rocm",
"CMAKE_CXX_COMPILER_LAUNCHER": "sccache",
"CMAKE_C_COMPILER_LAUNCHER": "sccache"
},
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Linux"
}
},
{
"name": "MI355-debug",
"displayName": "MI355 Debug",
"inherits": "linux-debug",
"description": "Development Environment for MI355.",
"cacheVariables": {
"GPU_TARGETS": "gfx950",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb"
}
},
{
"name": "MI355-release",
"displayName": "MI355 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx950",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "MI300X-release",
"displayName": "MI300X Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx942",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "MI250-release",
"displayName": "MI250 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx90a",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3",
"CK_USE_FP8_ON_UNSUPPORTED_ARCH":"ON"
}
},
{
"name": "MI250-debug",
"displayName": "MI250 Debug",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx90a",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb",
"CK_USE_FP8_ON_UNSUPPORTED_ARCH":"ON"
}
},
{
"name": "RX7800-release",
"displayName": "RX7800 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx1101",
"DL_KERNELS": "ON",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "RX7800-debug",
"displayName": "RX7800 Debug",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx1101",
"DL_KERNELS": "ON",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb"
}
}
],
"buildPresets": [
{
"name": "Debug",
"hidden": true,
"configuration": "Debug"
},
{
"name": "Release",
"hidden": true,
"configuration": "Release"
},
{
"name": "MI355-debug",
"displayName": "MI355",
"configurePreset": "MI355-debug",
"description": "Build Environment for MI355 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "MI355-release",
"displayName": "MI355",
"configurePreset": "MI355-release",
"description": "Build Environment for MI355 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI300X-release",
"displayName": "MI300X",
"configurePreset": "MI300X-release",
"description": "Build Environment for MI300X Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI250-release",
"displayName": "MI250",
"configurePreset": "MI250-release",
"description": "Build Environment for MI250 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI250-debug",
"displayName": "MI250",
"configurePreset": "MI250-debug",
"description": "Build Environment for MI250 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "RX7800-release",
"displayName": "RX7800",
"configurePreset": "RX7800-release",
"description": "Build Environment for RX7800 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "RX7800-debug",
"displayName": "RX7800",
"configurePreset": "RX7800-debug",
"description": "Build Environment for RX7800 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
}
]
}
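For reference, presets like these are driven through CMake's preset support (schema version 3 requires CMake 3.21 or newer); a typical invocation for one of the configurations above would be

    cmake --preset MI355-release
    cmake --build --preset MI355-release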
@@ -6,52 +6,55 @@
#include "mx_mfma_op.hpp"

using ck::e8m0_bexp_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;

template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma>
bool run_test()
{
    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::ColumnMajor;

    using AccType    = float; // only MFMA_F32 instructions supported
    using CPUAccType = AccType;

    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;

    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;

    const auto mx_mfma_kernel =
        ck::mx_mfma_test::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;

    bool pass = true;
    pass = ck::mx_mfma_test::TestMFMA<decltype(mx_mfma_kernel),
                                      AType,
                                      BType,
                                      CType,
                                      AccType,
                                      CPUAccType,
                                      ALayout,
                                      BLayout,
                                      CLayout,
                                      BLOCK_M,
                                      BLOCK_N,
                                      BLOCK_K>{}(mx_mfma_kernel);
    return pass;
}
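Purely as illustrative arithmetic (the numbers follow from the instruction shapes and the 64-lane wavefront, not from CK internals), the per-lane fragment sizes that the kernel in mx_mfma_op.hpp relies on can be sanity-checked like this:

    // Illustrative only: per-lane fragment sizes implied by the MFMA shapes on a 64-lane wave.
    constexpr int WAVE_SIZE = 64;
    // V_MFMA_F32_16X16X128_F8F6F4: 16x16 output tile, K = 128
    static_assert(16 * 128 / WAVE_SIZE == 32, "A fragment: 32 FP8 values per lane");
    static_assert(128 * 16 / WAVE_SIZE == 32, "B fragment: 32 FP8 values per lane");
    static_assert(16 * 16 / WAVE_SIZE == 4, "C/accum fragment: 4 values per lane");
    // V_MFMA_F32_32X32X64_F8F6F4: 32x32 output tile, K = 64
    static_assert(32 * 64 / WAVE_SIZE == 32, "A and B fragments: 32 FP8 values per lane");
    static_assert(32 * 32 / WAVE_SIZE == 16, "C/accum fragment: 16 values per lane");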
TEST(MFMA, FP8MFMA16x16x128)
{
    auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>();
    EXPECT_TRUE(pass);
}

TEST(MFMA, FP8MFMA32x32x64)
{
    auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>();
    EXPECT_TRUE(pass);
}
@@ -70,5 +73,5 @@ TEST(MXMFMA, FP8MFMA16x16x128)
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
// }

// TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
// TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
@@ -5,6 +5,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -12,114 +13,332 @@
namespace ck {
namespace mx_mfma_test {

// MFMA instructions supported in this test
enum class MFMA_F8F6F4
{
    F32_16x16x128 =
        static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
    F32_32x32x64 =
        static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4
};

template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_type_selector;

template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
{
    __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
    {
#if 1
        auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
        op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
#else
        ignore = fragA;
        ignore = fragB;
        ignore = fragAcc;
#endif
    }
};

template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
{
    __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
    {
#if 1
        auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
        op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
#else
        ignore = fragA;
        ignore = fragB;
        ignore = fragAcc;
#endif
    }
};

template <typename VecT>
static constexpr int32_t vectorSize(const VecT&)
{
    return scalar_type<VecT>::vector_size;
}
// Define a load function for input A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - Data is in col_major format, so columns of height BLOCK_M are contiguous in memory.
// - Each lane gathers VW consecutive K-neighbors of one row (stride BLOCK_M between them).
// This means:
// - From A we will load BLOCK_K columns of size BLOCK_M to satisfy our input data
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_col_major(AType const* input_ptr)
{
    // Here we want to load a BLOCK_M x BLOCK_K block of data.
    static constexpr uint32_t VW = vectorSize(AFragT{});

    using ARawT        = typename scalar_type<AFragT>::type;
    using AScalarFragT = vector_type<ARawT, VW>::type;

    // To start the loading process, let's visualize in 2D coords.
    // Each thread will load 32 elements.
    // We need to know where they start, and where the next elements are.
    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,         // Row
                                       (threadIdx.x / BLOCK_M) * VW); // Col
    auto stepCoord2D  = std::make_pair(0u, 1u);

    // Flatten to 1D col_major offsets.
    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

    // BLOCK_M is a stride in A matrix
    auto startOffset = col_major(startCoord2D, BLOCK_M);
    auto kOffset     = col_major(stepCoord2D, BLOCK_M);

    // kOffset == BLOCK_M
    // This means every BLOCK_M-th element is loaded into the output vector
    auto fragA = AScalarFragT{
        bit_cast<ARawT>(input_ptr[startOffset]),                 // XXX v[0]  = Reg 0 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 1 * kOffset]),   // XXX v[1]  = Reg 0 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 2 * kOffset]),   // XXX v[2]  = Reg 0 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 3 * kOffset]),   // XXX v[3]  = Reg 0 [24:31]
        bit_cast<ARawT>(input_ptr[startOffset + 4 * kOffset]),   // XXX v[4]  = Reg 1 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 5 * kOffset]),   // XXX v[5]  = Reg 1 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 6 * kOffset]),   // XXX v[6]  = Reg 1 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 7 * kOffset]),   // XXX v[7]  = Reg 1 [24:31]
        bit_cast<ARawT>(input_ptr[startOffset + 8 * kOffset]),   // XXX v[8]  = Reg 2 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 9 * kOffset]),   // XXX v[9]  = Reg 2 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 10 * kOffset]),  // XXX v[10] = Reg 2 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 11 * kOffset]),  // XXX v[11] = Reg 2 [24:31]
        bit_cast<ARawT>(input_ptr[startOffset + 12 * kOffset]),  // XXX v[12] = Reg 3 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 13 * kOffset]),  // XXX v[13] = Reg 3 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 14 * kOffset]),  // XXX v[14] = Reg 3 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 15 * kOffset]),  // XXX v[15] = Reg 3 [24:31]
        bit_cast<ARawT>(input_ptr[startOffset + 16 * kOffset]),  // XXX v[16] = Reg 4 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 17 * kOffset]),  // XXX v[17] = Reg 4 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 18 * kOffset]),  // XXX v[18] = Reg 4 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 19 * kOffset]),  // XXX v[19] = Reg 4 [24:31]
        bit_cast<ARawT>(input_ptr[startOffset + 20 * kOffset]),  // XXX v[20] = Reg 5 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 21 * kOffset]),  // XXX v[21] = Reg 5 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 22 * kOffset]),  // XXX v[22] = Reg 5 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 23 * kOffset]),  // XXX v[23] = Reg 5 [24:31]
        bit_cast<ARawT>(input_ptr[startOffset + 24 * kOffset]),  // XXX v[24] = Reg 6 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 25 * kOffset]),  // XXX v[25] = Reg 6 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 26 * kOffset]),  // XXX v[26] = Reg 6 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 27 * kOffset]),  // XXX v[27] = Reg 6 [24:31]
        bit_cast<ARawT>(input_ptr[startOffset + 28 * kOffset]),  // XXX v[28] = Reg 7 [0:7]
        bit_cast<ARawT>(input_ptr[startOffset + 29 * kOffset]),  // XXX v[29] = Reg 7 [8:15]
        bit_cast<ARawT>(input_ptr[startOffset + 30 * kOffset]),  // XXX v[30] = Reg 7 [16:23]
        bit_cast<ARawT>(input_ptr[startOffset + 31 * kOffset])}; // XXX v[31] = Reg 7 [24:31]

    return fragA;
}
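The addressing above is easiest to see with a small host-side sketch. This is standalone C++ (not part of the test); the constants assume the 16x16x128 case, where each lane owns 32 A values of a single row, spaced BLOCK_M apart in the col-major block:

    #include <cstdio>
    #include <initializer_list>
    #include <utility>

    int main()
    {
        constexpr int BLOCK_M = 16, VW = 32; // per-lane A fragment size for 16x16x128

        auto col_major = [](std::pair<int, int> c, int ld) { return c.first + c.second * ld; };

        for(int lane : {0, 1, 17, 63}) // a few sample lanes of the 64-wide wavefront
        {
            std::pair<int, int> start{lane % BLOCK_M, (lane / BLOCK_M) * VW}; // (row, col)
            int startOffset = col_major(start, BLOCK_M);
            int kOffset     = col_major({0, 1}, BLOCK_M); // == BLOCK_M

            // The lane reads VW values of one row, BLOCK_M elements apart in memory.
            std::printf("lane %2d: row %2d, cols [%3d..%3d], offsets %d, %d, ..., %d\n",
                        lane, start.first, start.second, start.second + VW - 1,
                        startOffset, startOffset + kOffset, startOffset + (VW - 1) * kOffset);
        }
    }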
// Define a load function for input B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - Data is in col_major format, so columns of height BLOCK_K are contiguous in memory.
// - Each lane wants VW consecutive K-neighbors of one column of B in register.
// This means:
// - From B we will load BLOCK_N columns of size BLOCK_K to satisfy our input data
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
__device__ BFragT load_B_col_major(BType const* input_ptr)
{
    // Here we want to load a BLOCK_K x BLOCK_N block of data.
    static constexpr uint32_t VW = vectorSize(BFragT{});

    // To start the loading process, let's visualize in 2D coords.
    // Each thread will load 32 elements.
    // We need to know where they start, and where the next elements are.
    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
                                       threadIdx.x % BLOCK_N);       // Col
    // auto stepCoord2D = std::make_pair(1u, 0u);

    // Flatten to 1D col_major offsets.
    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

    auto startOffset = col_major(startCoord2D, BLOCK_K);
    // auto kOffset = col_major(stepCoord2D, BLOCK_K);

    // kOffset == 1, so each lane's VW elements are contiguous and can be read with one vector load.
    auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
    return *fragPtr;
}
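The same kind of sketch for B (again standalone host C++ with the 16x16x128 constants assumed) shows why a single vector load is enough: each lane's 32 B values sit at consecutive col-major offsets within one column:

    #include <cstdio>
    #include <initializer_list>

    int main()
    {
        constexpr int BLOCK_K = 128, BLOCK_N = 16, VW = 32; // per-lane B fragment size

        for(int lane : {0, 15, 16, 63})
        {
            int row         = (lane / BLOCK_N) * VW;   // starting K index
            int col         = lane % BLOCK_N;          // N index
            int startOffset = row + col * BLOCK_K;     // col-major flatten, ld == BLOCK_K

            // Elements startOffset .. startOffset + VW - 1 all belong to one column of B.
            std::printf("lane %2d: column %2d, k rows [%3d..%3d], offsets [%4d..%4d]\n",
                        lane, col, row, row + VW - 1, startOffset, startOffset + VW - 1);
        }
    }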
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - Data is in col_major format, so columns of height BLOCK_M are contiguous in memory.
// - Each lane holds VW consecutive M-neighbors of one column of C in register.
// This means:
// - To C we will store BLOCK_N columns of size BLOCK_M
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct store_C_col_major;

// Here we want to store a 16x16 block of data.
//
//                 Size |   BLOCK_N   |   BLOCK_N   |   BLOCK_N   |   BLOCK_N   | Vector
//     Register Element |  0  ...  15 |  16 ...  31 |  32 ...  47 |  48 ...  63 | Element
//                      | ____________ _____________ _____________ _____________
//                 Reg0 |     M0      |     M4      |     M8      |     M12     | v[0]
//                 Reg1 |     M1      |     M5      |     M9      |     M13     | v[1]
//                 Reg2 |     M2      |     M6      |     M10     |     M14     | v[2]
//                 Reg3 |     M3      |     M7      |     M11     |     M15     | v[3]
template <typename CType, typename CFragT>
struct store_C_col_major<CType, CFragT, 16, 16>
{
    __device__ void operator()(CType* output, CFragT cFrag)
    {
        static constexpr uint32_t VW  = vectorSize(cFrag); // 4
        static constexpr uint32_t Dim = 16;

#if 1
        for(int i = 0; i < vectorSize(cFrag); ++i)
        {
            printf("threadIdx.x = %d; cFrag[%d] = %f\n",
                   static_cast<int>(threadIdx.x),
                   i,
                   static_cast<float>(cFrag[i]));
        }
#endif

        // Each thread will store 4 elements.
        // We need to know where they start, and where the next elements are.
        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                           threadIdx.x % Dim);       // Col
        // auto stepCoord2D = std::make_pair(1u, 0u);

        // Flatten to 1D col_major offsets.
        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

        auto startOffset = col_major(startCoord2D, 16);
        // auto kOffset = col_major(stepCoord2D, 16); // 1

        // kOffset == 1, so the 4 elements of cFrag land in 4 consecutive rows of one column.
        auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
        *fragPtr      = cFrag;
    }
};
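The 16x16 mapping table above can be reproduced with a few lines of host C++ (standalone; constants implied by the MFMA shape): vector element i of lane t goes to row (t / 16) * 4 + i of column t % 16:

    #include <cstdio>

    int main()
    {
        constexpr int Dim = 16, VW = 4, WAVE_SIZE = 64;

        for(int lane = 0; lane < WAVE_SIZE; lane += 16) // one representative lane per lane-group
        {
            for(int i = 0; i < VW; ++i)
            {
                int row = (lane / Dim) * VW + i; // M index
                int col = lane % Dim;            // N index
                std::printf("lane %2d, v[%d] -> C(%2d, %2d), col-major offset %3d\n",
                            lane, i, row, col, row + col * Dim);
            }
        }
    }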
// Here we want to store a 32x32 block of data.
// Register Mapping:
//
//                 Size |   BLOCK_N   |   BLOCK_N   | Vector
//     Register Element |  0  ...  31 |  32 ...  63 | Element
//                      | ____________ _____________
//                 Reg0 |     M0      |     M4      | v[0]
//                 Reg1 |     M1      |     M5      | v[1]
//                 Reg2 |     M2      |     M6      | v[2]
//                 Reg3 |     M3      |     M7      | v[3]
//                      | ____________ _____________
//                 Reg4 |     M8      |     M12     | v[4]
//                 Reg5 |     M9      |     M13     | v[5]
//                 Reg6 |     M10     |     M14     | v[6]
//                 Reg7 |     M11     |     M15     | v[7]
//                      | ____________ _____________
//                 Reg8 |     M16     |     M20     | v[8]
//                 Reg9 |     M17     |     M21     | v[9]
//                Reg10 |     M18     |     M22     | v[10]
//                Reg11 |     M19     |     M23     | v[11]
//                      | ____________ _____________
//                Reg12 |     M24     |     M28     | v[12]
//                Reg13 |     M25     |     M29     | v[13]
//                Reg14 |     M26     |     M30     | v[14]
//                Reg15 |     M27     |     M31     | v[15]
template <typename CType, typename CFragT>
struct store_C_col_major<CType, CFragT, 32, 32>
{
    __device__ void operator()(CType* output, CFragT cFrag)
    {
        static constexpr uint32_t WAVE_SIZE      = 64;
        static constexpr uint32_t VW             = 4;
        static constexpr uint32_t Dim            = 32;
        static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8

#if 1
        for(int i = 0; i < vectorSize(cFrag); ++i)
        {
            printf("threadIdx.x = %d; cFrag[%d] = %f\n",
                   static_cast<int>(threadIdx.x),
                   i,
                   static_cast<float>(cFrag[i]));
        }
#endif

        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                           threadIdx.x % Dim);       // Col

        // Minor step for each 'chunk'
        // auto minorStepCoord2D = std::make_pair(1u, 0u);
        // Major step between 'chunks'
        auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);

        // Flatten to 1D col_major offsets.
        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

        auto startOffset = col_major(startCoord2D, 32);
        // auto kMinorOffset = col_major(minorStepCoord2D, 32); // 1
        auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8

        // kMinorOffset == 1.
        // This means we can vector store 4 contiguous elements at a time.
        using CRawT        = typename scalar_type<CFragT>::type;
        using CScalarFragT = vector_type<CRawT, VW>::type;

        union
        {
            CFragT frag;
            CScalarFragT chunks[vectorSize(CFragT{}) / VW];
        } fragC{cFrag}; // Initialize with input fragment

        *(reinterpret_cast<CScalarFragT*>(output + startOffset))                    = fragC.chunks[0];
        *(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset))     = fragC.chunks[1];
        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) = fragC.chunks[2];
        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) = fragC.chunks[3];
    }
};
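For the 32x32 tile the distinctive part is the M_PER_VW_CHUNK = 8 row jump between the four 4-element chunks; a standalone host C++ sketch of that mapping (constants implied by the MFMA shape):

    #include <cstdio>

    int main()
    {
        constexpr int Dim = 32, VW = 4, WAVE_SIZE = 64;
        constexpr int M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8

        int lane = 37; // any lane of the wavefront; column = 37 % 32 = 5
        for(int chunk = 0; chunk < 16 / VW; ++chunk) // 16 values per lane, 4 chunks of 4
        {
            for(int i = 0; i < VW; ++i)
            {
                int row = (lane / Dim) * VW + chunk * M_PER_VW_CHUNK + i; // M index
                int col = lane % Dim;                                     // N index
                std::printf("lane %d, v[%2d] -> C(%2d, %2d)\n", lane, chunk * VW + i, row, col);
            }
        }
    }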
template <typename AType,
          typename BType,
          typename CType,
          typename AccType,
          int32_t BLOCK_M,
          int32_t BLOCK_N,
          int32_t BLOCK_K>
__global__ void matmul(const AType* a, const BType* b, CType* c)
{
    constexpr int WAVE_SIZE = 64;
    assert(threadIdx.x < WAVE_SIZE);
    assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);

    using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
    using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
    using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;

    using AccumFragT    = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
    using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;

    // Create frags
    auto fragA   = AFragT{};
    auto fragB   = BFragT{};
    auto fragC   = CFragT{};
    auto fragAcc = AccumFragT{0};

    // Load the inputs.
    // A = col major, BLOCK_M x BLOCK_K
    fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
    // B = col major, BLOCK_K x BLOCK_N
    fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);

    // Matrix multiply-accumulate using MFMA units
    // Accumulation intermediate = BLOCK_M x BLOCK_N
    mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);

    for(int i = 0; i < vectorSize(fragC); ++i)
    {
        fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
    }

    auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
    storeC(c, fragC);
}

/**
@@ -191,7 +410,7 @@ bool RunDeviceGEMM(KernelType kernel,
    return true;
}

template <typename DeviceMFMA,
          typename ADataType,
          typename BDataType,
          typename CDataType,
@@ -200,14 +419,10 @@ template <typename DeviceMFMA,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          index_t BLOCK_M,
          index_t BLOCK_N,
          index_t BLOCK_K>
struct TestMFMA
{
    auto PrepareGemmTensors(const GemmParams& params)
    {
@@ -234,25 +449,25 @@ struct TestMXMFMA
        Tensor<CDataType> c_m_n_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
        b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});

        return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
    }

    auto operator()(const DeviceMFMA& mfma_kernel)
    {
        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                  << ", CLayout = " << CLayout{}.name << std::endl;

        // Arrange
        GemmParams params;
        params.M       = BLOCK_M;
        params.N       = BLOCK_N;
        params.K       = BLOCK_K;
        params.StrideA = BLOCK_K; // M K
        params.StrideB = BLOCK_N; // K N
        params.StrideC = BLOCK_N; // M N

        auto host_tensors = PrepareGemmTensors(params);
@@ -261,25 +476,27 @@ struct TestMXMFMA
        Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
        Tensor<CDataType>& c_device = std::get<3>(host_tensors);

        using PassThrough = ck::tensor_operation::element_wise::PassThrough;

        auto a_element_op = PassThrough{};
        auto b_element_op = PassThrough{};
        auto c_element_op = PassThrough{};

        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                CDataType,
                                                                                CPUAccDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;
        RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);

        RunDeviceGEMM(mfma_kernel, a, b, c_device);

        bool res = false;
        if constexpr(std::is_same<CDataType, float>::value ||
                     std::is_same<CDataType, half_t>::value)
        {
            res = ck::utils::check_err(c_device.mData, c_host.mData);
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
...