Commit eb06c923 authored by Andriy Roshchenko

Add tests for MFMA_F8F6F4::F32_16x16x128 and MFMA_F8F6F4::F32_32x32x64 instructions

parent a619e3f5
......@@ -530,7 +530,7 @@ endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
add_compile_options(-fcolor-diagnostics)
# add_compile_options(-fcolor-diagnostics)
endif()
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9)
add_compile_options(-fdiagnostics-color=always)
......
{
"version": 3,
"configurePresets": [
{
"name": "linux-debug",
"displayName": "Linux Debug",
"hidden": true,
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build/${presetName}",
"installDir": "${sourceDir}/build/install/${presetName}",
"environment": {
"MY_ENVIRONMENT_VARIABLE": "NONE",
"PATH": "/usr/local/.cargo/bin:$penv{PATH}",
"SCCACHE_IDLE_TIMEOUT": "11000"
},
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"BUILD_DEV": "ON",
"CMAKE_CXX_COMPILER": "/opt/rocm/bin/hipcc",
"CMAKE_PREFIX_PATH": "/opt/rocm",
"CMAKE_CXX_COMPILER_LAUNCHER": "sccache",
"CMAKE_C_COMPILER_LAUNCHER": "sccache"
},
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Linux"
}
},
{
"name": "MI355-debug",
"displayName": "MI355 Debug",
"inherits": "linux-debug",
"description": "Development Environment for MI355.",
"cacheVariables": {
"GPU_TARGETS": "gfx950",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb"
}
},
{
"name": "MI355-release",
"displayName": "MI355 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx950",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "MI300X-release",
"displayName": "MI300X Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx942",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "MI250-release",
"displayName": "MI250 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx90a",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3",
"CK_USE_FP8_ON_UNSUPPORTED_ARCH":"ON"
}
},
{
"name": "MI250-debug",
"displayName": "MI250 Debug",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx90a",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb",
"CK_USE_FP8_ON_UNSUPPORTED_ARCH":"ON"
}
},
{
"name": "RX7800-release",
"displayName": "RX7800 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx1101",
"DL_KERNELS": "ON",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "RX7800-debug",
"displayName": "RX7800 Debug",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx1101",
"DL_KERNELS": "ON",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb"
}
}
],
"buildPresets": [
{
"name": "Debug",
"hidden": true,
"configuration": "Debug"
},
{
"name": "Release",
"hidden": true,
"configuration": "Release"
},
{
"name": "MI355-debug",
"displayName": "MI355",
"configurePreset": "MI355-debug",
"description": "Build Environment for MI355 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "MI355-release",
"displayName": "MI355",
"configurePreset": "MI355-release",
"description": "Build Environment for MI355 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI300X-release",
"displayName": "MI300X",
"configurePreset": "MI300X-release",
"description": "Build Environment for MI300X Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI250-release",
"displayName": "MI250",
"configurePreset": "MI250-release",
"description": "Build Environment for MI250 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI250-debug",
"displayName": "MI250",
"configurePreset": "MI250-debug",
"description": "Build Environment for MI250 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "RX7800-release",
"displayName": "RX7800",
"configurePreset": "RX7800-release",
"description": "Build Environment for RX7800 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "RX7800-debug",
"displayName": "RX7800",
"configurePreset": "RX7800-debug",
"description": "Build Environment for RX7800 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
}
]
}
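With these presets in place, configuration and building can be driven directly from the command line using the standard CMake preset workflow (CMake 3.21+ is needed for version-3 preset files); for example, for the MI355 debug configuration defined above:

    cmake --preset MI355-debug
    cmake --build --preset MI355-debug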
......@@ -6,52 +6,55 @@
#include "mx_mfma_op.hpp"
using ck::e8m0_bexp_t;
using ck::f8_ocp_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;
template <typename Src1Type,
ck::index_t Src1VecSize,
typename Src2Type,
ck::index_t Src2VecSize,
typename DstType,
ck::index_t AccVecSize,
typename AccType,
typename CPUAccType,
ck::index_t M,
ck::index_t N,
ck::index_t K>
template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma>
bool run_test()
{
using Row = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
bool pass = true;
const auto mx_mfma_kernel = ck::mx_mfma_test::
matmul<Src1Type, Src1VecSize, Src2Type, Src2VecSize, AccType, AccVecSize, DstType, M, N, K>;
pass = ck::mx_mfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
Src1Type,
Src2Type,
DstType,
AccType,
CPUAccType,
decltype(Row{}),
decltype(Row{}),
decltype(Row{}),
PassThrough,
PassThrough,
PassThrough,
AccVecSize,
M,
N,
K>{}(mx_mfma_kernel);
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
using AccType = float; // only MFMA_F32 instructions supported
using CPUAccType = AccType;
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
const auto mx_mfma_kernel =
ck::mx_mfma_test::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
bool pass = true;
pass = ck::mx_mfma_test::TestMFMA<decltype(mx_mfma_kernel),
AType,
BType,
CType,
AccType,
CPUAccType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K>{}(mx_mfma_kernel);
return pass;
}
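// For the two instructions exercised below, the block sizes derived above from mfma_type are:
//   MFMA_F8F6F4::F32_16x16x128 -> BLOCK_M = 16, BLOCK_N = 16, BLOCK_K = 128
//   MFMA_F8F6F4::F32_32x32x64  -> BLOCK_M = 32, BLOCK_N = 32, BLOCK_K = 64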
TEST(MXMFMA, FP8MFMA16x16x128)
TEST(MFMA, FP8MFMA16x16x128)
{
auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>();
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
auto pass = run_test<float, 1, float, 1, float, 1, float, float, 16, 16, 128>();
auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>();
EXPECT_TRUE(pass);
}
......@@ -70,5 +73,5 @@ TEST(MXMFMA, FP8MFMA16x16x128)
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
// }
TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
// TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
// TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
......@@ -5,6 +5,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -12,114 +13,332 @@
namespace ck {
namespace mx_mfma_test {
template <typename src_vec1, typename src_vec2, typename acc_vec>
__device__ void builtin_mx_mfma_naive_selector(const src_vec1&, const src_vec2&, acc_vec&)
// MFMA instructions supported in this test
enum class MFMA_F8F6F4
{
}
F32_16x16x128 =
static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
F32_32x32x64 =
static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4
};
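// A minimal sketch of how the test driver consumes these values (it mirrors run_test in the
// test file above; 'instr' and 'traits' are illustrative names, not part of this header):
//   constexpr auto instr = static_cast<MfmaInstr>(MFMA_F8F6F4::F32_16x16x128);
//   mfma_type<instr> traits; // traits.m_per_blk == 16, traits.num_input_blks * traits.k_per_blk == 128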
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_type_selector;
// Smfmac instructions use 4:2 structural sparsity: in every contiguous subgroup of 4 elements,
// at least 2 must be equal to zero, and the positions of the non-zero elements are stored in the
// idx register so that the corresponding B matrix elements can be selected for multiplication.
// Currently smfmac instructions support sparsity only for the A matrix.
template <typename src1_t,
index_t src1_vec_size,
typename src2_t,
index_t src2_vec_size,
typename acc_t,
index_t acc_vec_size,
typename dst_t,
int32_t M,
int32_t N,
int32_t K>
__global__ void matmul(const src1_t* a, const src2_t* b, dst_t* c)
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
{
__shared__ src1_t a_shared[M * K];
__shared__ src2_t b_shared[K * N];
const int lane = threadIdx.x;
// smfmac's A part stores only the non-zero elements in 2 VGPRs
// smfmac's B part stores all elements in 4 VGPRs
using src1_vec = typename vector_type<src1_t, src1_vec_size>::type;
using src1_full_vec = typename vector_type<src1_t, src1_vec_size * 2>::type;
using src2_vec = typename vector_type<src2_t, src2_vec_size>::type;
src1_vec a_frag = {};
src2_vec b_frag = {};
src1_full_vec a_temp = {};
src2_vec b_temp = {};
// initialize c fragment to 0
using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_vec_size, true>;
acc_vec c_thread_buf_;
for(int i = 0; i < 8; ++i)
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
a_temp[i] = a[(lane % M) * K + (lane / M) * 8 + i]; // M K
#if 1
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
#else
ignore = fragA;
ignore = fragB;
ignore = fragAcc;
#endif
}
};
for(int i = 0; i < 8; ++i)
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
b_temp[i] = b[(8 * (lane / N) + i) * N + (lane % N)]; // K N
#if 1
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
#else
ignore = fragA;
ignore = fragB;
ignore = fragAcc;
#endif
}
};
__syncthreads();
template <typename VecT>
static constexpr int32_t vectorSize(const VecT&)
{
return scalar_type<VecT>::vector_size;
}
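// e.g. vectorSize(vector_type<float, 4>::type{}) == 4, since scalar_type reports the number
// of elements in the ext-vector fragment type.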
for(int i = 0; i < 8; ++i)
{
a_shared[(lane % M) * K + (lane / M) * 8 + i] = a_temp[i];
}
for(int i = 0; i < 8; ++i)
{
b_shared[(8 * (lane / N) + i) * N + (lane % N)] = b_temp[i];
}
// Define a load function for input A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in col_major format
// This means:
// - From A we will load K columns of size BLOCK_M to satisfy our input data
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_col_major(AType const* input_ptr)
{
// Here we want to load a BLOCK_M x BLOCK_K block of data.
static constexpr uint32_t VW = vectorSize(AFragT{});
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, VW>::type;
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW); // Col
auto stepCoord2D = std::make_pair(0u, 1u);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
// BLOCK_M is a stride in A matrix
auto startOffset = col_major(startCoord2D, BLOCK_M);
auto kOffset = col_major(stepCoord2D, BLOCK_M);
// kOffset == BLOCK_M
// This means every BLOCK_M-th element in memory is loaded into the output vector
auto fragA = AScalarFragT{
bit_cast<ARawT>(input_ptr[startOffset]), // XXX v[0] = Reg 0 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 1 * kOffset]), // XXX v[1] = Reg 0 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 2 * kOffset]), // XXX v[2] = Reg 0 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 3 * kOffset]), // XXX v[3] = Reg 0 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 4 * kOffset]), // XXX v[4] = Reg 1 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 5 * kOffset]), // XXX v[5] = Reg 1 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 6 * kOffset]), // XXX v[6] = Reg 1 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 7 * kOffset]), // XXX v[7] = Reg 1 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 8 * kOffset]), // XXX v[8] = Reg 2 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 9 * kOffset]), // XXX v[9] = Reg 2 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 10 * kOffset]), // XXX v[10] = Reg 2 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 11 * kOffset]), // XXX v[11] = Reg 2 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 12 * kOffset]), // XXX v[12] = Reg 3 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 13 * kOffset]), // XXX v[13] = Reg 3 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 14 * kOffset]), // XXX v[14] = Reg 3 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 15 * kOffset]), // XXX v[15] = Reg 3 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 16 * kOffset]), // XXX v[16] = Reg 4 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 17 * kOffset]), // XXX v[17] = Reg 4 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 18 * kOffset]), // XXX v[18] = Reg 4 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 19 * kOffset]), // XXX v[19] = Reg 4 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 20 * kOffset]), // XXX v[20] = Reg 5 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 21 * kOffset]), // XXX v[21] = Reg 5 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 22 * kOffset]), // XXX v[22] = Reg 5 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 23 * kOffset]), // XXX v[23] = Reg 5 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 24 * kOffset]), // XXX v[24] = Reg 6 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 25 * kOffset]), // XXX v[25] = Reg 6 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 26 * kOffset]), // XXX v[26] = Reg 6 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 27 * kOffset]), // XXX v[27] = Reg 6 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 28 * kOffset]), // XXX v[28] = Reg 7 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 29 * kOffset]), // XXX v[29] = Reg 7 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 30 * kOffset]), // XXX v[30] = Reg 7 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 31 * kOffset])}; // XXX v[31] = Reg 7 [24:31]
return fragA;
}
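// For reference, the unrolled gather above is equivalent to the loop below (a minimal sketch
// only; it reuses the local names VW, ARawT, startOffset and kOffset from this function and
// assumes element-wise indexing of the ext-vector fragment):
//
//   AScalarFragT frag{};
//   for(uint32_t k = 0; k < VW; ++k)
//   {
//       // lane (threadIdx.x % BLOCK_M) walks its row across BLOCK_K columns,
//       // stepping one full column (kOffset == BLOCK_M elements) per iteration
//       frag[k] = bit_cast<ARawT>(input_ptr[startOffset + k * kOffset]);
//   }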
// Define a load function for input B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_K sized column neighbors in register.
// - Data is in col_major format
// This means:
// - From B each thread will load a contiguous chunk of one column (along K) to satisfy our input data
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
__device__ BFragT load_B_col_major(BType const* input_ptr)
{
// Here we want to load a BLOCK_K x BLOCK_N block of data.
static constexpr uint32_t VW = vectorSize(BFragT{});
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
threadIdx.x % BLOCK_N); // Col
// auto stepCoord2D = std::make_pair(1u, 0u);
__syncthreads();
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
// Idx must be a 32-bit register; it stores four 2-bit indexes of A's non-zero elements.
// It starts with the last two elements of every 4-element subgroup marked as non-zero
int32_t idx = 0b11101110;
// Bit masks are for zeroing 0-3rd position of idx
static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
auto startOffset = col_major(startCoord2D, BLOCK_K);
// auto kOffset = col_major(stepCoord2D, BLOCK_K);
src1_t curr_val;
int32_t a_pos = 0;
for(int j = 0; j < 2; ++j)
// kOffset == 1
auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
return *fragPtr;
}
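// Worked example (assuming the F32_16x16x128 case with FP8 inputs, i.e. BLOCK_K = 128,
// BLOCK_N = 16, VW = 32): lane 17 gets startCoord2D = (32, 1), so startOffset = 32 + 1 * 128 = 160
// and the lane reads B elements k = 32..63 of column n = 1 as one contiguous vector.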
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - Each lane holds contiguous column (M-direction) neighbors in register.
// - Data is in col_major format
// This means:
// - To C each lane will store contiguous chunks of one column (along M)
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct store_C_col_major;
// Here we want to store a 16x16 block of data.
//
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | Vector
// Register Element | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element
// ____________ _____________ _____________ ______________
// Reg0 | M0 | M4 | M8 | M12 | v[0]
// Reg1 | M1 | M5 | M9 | M13 | v[1]
// Reg2 | M2 | M6 | M10 | M14 | v[2]
// Reg3 | M3 | M7 | M11 | M15 | v[3]
template <typename CType, typename CFragT>
struct store_C_col_major<CType, CFragT, 16, 16>
{
__device__ void operator()(CType* output, CFragT cFrag)
{
a_pos = j * 2;
for(int i = 0; i < 4; ++i)
static constexpr uint32_t VW = vectorSize(cFrag); // 4
static constexpr uint32_t Dim = 16;
#if 1
for(int i = 0; i < vectorSize(cFrag); ++i)
{
curr_val = a_shared[(lane % M) * K + (lane / M) * 8 + 4 * j + i];
if(curr_val != 0.0f)
{
idx &= ~bit_clear_masks[a_pos];
idx |= (i % 4) << 2 * a_pos;
a_frag[a_pos] = curr_val;
a_pos++;
}
printf("threadIdx.x = %d; cFrag[%d] = %f\n",
static_cast<int>(threadIdx.x),
i,
static_cast<float>(cFrag[i]));
}
#endif
// Each thread will store 4 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
// auto stepCoord2D = std::make_pair(1u, 0u);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, 16);
// auto kOffset = col_major(stepCoord2D, 16); // 1
// kOffset == 1
auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
*fragPtr = cFrag;
}
};
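// Worked example for the 16x16 store: lane 37 gets startCoord2D = ((37 / 16) * 4, 37 % 16)
// = (8, 5), so startOffset = 8 + 5 * 16 = 88 and v[0..3] land in rows 8..11 of column 5 as
// one contiguous col_major write.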
for(int i = 0; i < 8; ++i)
// Here we want to store a 32x32 block of data.
// Register Mapping:
// Size | BLOCK_N | BLOCK_N | Vector
// Register Element | 0 ... 31 | 32 ... 63 | Element
// ____________ _____________
// Reg0 | M0 | M4 | v[0]
// Reg1 | M1 | M5 | v[1]
// Reg2 | M2 | M6 | v[2]
// Reg3 | M3 | M7 | v[3]
// ____________ _____________
// Reg4 | M8 | M12 | v[4]
// Reg5 | M9 | M13 | v[5]
// Reg6 | M10 | M14 | v[6]
// Reg7 | M11 | M15 | v[7]
// ____________ _____________
// Reg8 | M16 | M20 | v[8]
// Reg9 | M17 | M21 | v[9]
// Reg10 | M18 | M22 | v[10]
// Reg11 | M19 | M23 | v[11]
// ____________ _____________
// Reg12 | M24 | M28 | v[12]
// Reg13 | M25 | M29 | v[13]
// Reg14 | M26 | M30 | v[14]
// Reg15 | M27 | M31 | v[15]
template <typename CType, typename CFragT>
struct store_C_col_major<CType, CFragT, 32, 32>
{
__device__ void operator()(CType* output, CFragT cFrag)
{
b_frag[i] = b_shared[(8 * (lane / N) + i) * N + (lane % N)];
}
static constexpr uint32_t WAVE_SIZE = 64;
static constexpr uint32_t VW = 4;
static constexpr uint32_t Dim = 32;
static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
builtin_smfmac_naive_selector<src1_vec, src2_vec, acc_vec>(a_frag, b_frag, idx, c_thread_buf_);
__syncthreads();
#if 1
for(int i = 0; i < vectorSize(cFrag); ++i)
{
printf("threadIdx.x = %d; cFrag[%d] = %f\n",
static_cast<int>(threadIdx.x),
i,
static_cast<float>(cFrag[i]));
}
#endif
// store results from unpacked c_thread_buf_ output
if constexpr(K == 32)
{
static_for<0, acc_vec_size, 1>{}([&](auto i) {
c[(4 * (lane / 16) + i) * N + lane % 16] =
ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
});
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
// Minor step for each 'chunk'
// auto minorStepCoord2D = std::make_pair(1u, 0u);
// Major step between 'chunks'
auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, 32);
// auto kMinorOffset = col_major(minorStepCoord2D, 32); // 1
auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
// kMinorOffset == 1.
// This means we can vector store 4 contiguous elements at a time.
using CRawT = typename scalar_type<CFragT>::type;
using CScalarFragT = vector_type<CRawT, VW>::type;
union
{
CFragT frag;
CScalarFragT chunks[vectorSize(CFragT{}) / VW];
} fragC{cFrag}; // Initialize with input fragment
*(reinterpret_cast<CScalarFragT*>(output + startOffset)) = fragC.chunks[0];
*(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset)) = fragC.chunks[1];
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) =
fragC.chunks[2];
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) =
fragC.chunks[3];
}
else
};
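// Worked example for the 32x32 store: lane 37 gets startCoord2D = ((37 / 32) * 4, 37 % 32)
// = (4, 5), so startOffset = 4 + 5 * 32 = 164 and the four VW-sized chunks go to rows 4..7,
// 12..15, 20..23 and 28..31 of column 5, each kMajorOffset = 8 rows apart.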
template <typename AType,
typename BType,
typename CType,
typename AccType,
int32_t BLOCK_M,
int32_t BLOCK_N,
int32_t BLOCK_K>
__global__ void matmul(const AType* a, const BType* b, CType* c)
{
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == WAVE_SIZE && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
// Create frags
auto fragA = AFragT{};
auto fragB = BFragT{};
auto fragC = CFragT{};
auto fragAcc = AccumFragT{0};
// Load the inputs.
// A = col major, BLOCK_M x BLOCK_K
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
for(int i = 0; i < vectorSize(fragC); ++i)
{
static_for<0, acc_vec_size, 1>{}([&](auto i) {
c[((8 * (i / 4)) % 32 + 4 * (lane / 32) + i % 4) * N + lane % 32] =
ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
});
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
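// Launch sketch: RunDeviceGEMM (defined later in this file) is not shown in full in this diff;
// based on the asserts above, the kernel is assumed to be launched as a single 64-lane wavefront
// computing one BLOCK_M x BLOCK_N tile, roughly:
//
//   matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>
//       <<<dim3(1), dim3(WAVE_SIZE)>>>(a_dev_ptr, b_dev_ptr, c_dev_ptr);
//   hipDeviceSynchronize();
//
// where a_dev_ptr, b_dev_ptr and c_dev_ptr are illustrative device-buffer names.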
/**
......@@ -191,7 +410,7 @@ bool RunDeviceGEMM(KernelType kernel,
return true;
}
template <typename DeviceMXMFMA,
template <typename DeviceMFMA,
typename ADataType,
typename BDataType,
typename CDataType,
......@@ -200,14 +419,10 @@ template <typename DeviceMXMFMA,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t CAccNum,
index_t M,
index_t N,
index_t K>
struct TestMXMFMA
index_t BLOCK_M,
index_t BLOCK_N,
index_t BLOCK_K>
struct TestMFMA
{
auto PrepareGemmTensors(const GemmParams& params)
{
......@@ -234,25 +449,25 @@ struct TestMXMFMA
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
}
auto operator()(const DeviceMXMFMA& mfma_kernel)
auto operator()(const DeviceMFMA& mfma_kernel)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
// Arrange
GemmParams params;
params.M = M;
params.N = N;
params.K = K;
params.StrideA = K; // M K
params.StrideB = N; // K N
params.StrideC = N; // M N
params.M = BLOCK_M;
params.N = BLOCK_N;
params.K = BLOCK_K;
params.StrideA = BLOCK_K; // M K
params.StrideB = BLOCK_N; // K N
params.StrideC = BLOCK_N; // M N
auto host_tensors = PrepareGemmTensors(params);
......@@ -261,25 +476,27 @@ struct TestMXMFMA
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{};
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
CPUAccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
CPUAccDataType,
PassThrough,
PassThrough,
PassThrough>;
RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);
RunDeviceGEMM(mfma_kernel, a, b, c_device);
bool res = false;
if constexpr(std::is_same<CDataType, float>::value)
if constexpr(std::is_same<CDataType, float>::value ||
std::is_same<CDataType, half_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
......