Commit f3af1da6 authored by Andriy Roshchenko

Merge remote-tracking branch 'internal/andriy/lwpck-2788' into andriy/lwpck-2788

parents 2bef5501 60b885ae
...@@ -541,7 +541,7 @@ endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
-    add_compile_options(-fcolor-diagnostics)
+    # add_compile_options(-fcolor-diagnostics)
endif()
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9)
add_compile_options(-fdiagnostics-color=always)
......
{
"version": 3,
"configurePresets": [
{
"name": "linux-debug",
"displayName": "Linux Debug",
"hidden": true,
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build/${presetName}",
"installDir": "${sourceDir}/build/install/${presetName}",
"environment": {
"MY_ENVIRONMENT_VARIABLE": "NONE",
"PATH": "/usr/local/.cargo/bin:$penv{PATH}",
"SCCACHE_IDLE_TIMEOUT": "11000"
},
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"BUILD_DEV": "ON",
"CMAKE_CXX_COMPILER": "/opt/rocm/bin/hipcc",
"CMAKE_PREFIX_PATH": "/opt/rocm",
"CMAKE_CXX_COMPILER_LAUNCHER": "sccache",
"CMAKE_C_COMPILER_LAUNCHER": "sccache"
},
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Linux"
}
},
{
"name": "MI355-debug",
"displayName": "MI355 Debug",
"inherits": "linux-debug",
"description": "Development Environment for MI355.",
"cacheVariables": {
"GPU_TARGETS": "gfx950",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb"
}
},
{
"name": "MI355-release",
"displayName": "MI355 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx950",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "MI300X-release",
"displayName": "MI300X Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx942",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "MI250-release",
"displayName": "MI250 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx90a",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3",
"CK_USE_FP8_ON_UNSUPPORTED_ARCH":"ON"
}
},
{
"name": "MI250-debug",
"displayName": "MI250 Debug",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx90a",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb",
"CK_USE_FP8_ON_UNSUPPORTED_ARCH":"ON"
}
},
{
"name": "RX7800-release",
"displayName": "RX7800 Release",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx1101",
"DL_KERNELS": "ON",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_FLAGS": "-O3"
}
},
{
"name": "RX7800-debug",
"displayName": "RX7800 Debug",
"inherits": "linux-debug",
"cacheVariables": {
"GPU_TARGETS": "gfx1101",
"DL_KERNELS": "ON",
"CMAKE_BUILD_TYPE": "Debug",
"CMAKE_CXX_FLAGS": "-O0 -ggdb"
}
}
],
"buildPresets": [
{
"name": "Debug",
"hidden": true,
"configuration": "Debug"
},
{
"name": "Release",
"hidden": true,
"configuration": "Release"
},
{
"name": "MI355-debug",
"displayName": "MI355",
"configurePreset": "MI355-debug",
"description": "Build Environment for MI355 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "MI355-release",
"displayName": "MI355",
"configurePreset": "MI355-release",
"description": "Build Environment for MI355 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI300X-release",
"displayName": "MI300X",
"configurePreset": "MI300X-release",
"description": "Build Environment for MI300X Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI250-release",
"displayName": "MI250",
"configurePreset": "MI250-release",
"description": "Build Environment for MI250 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "MI250-debug",
"displayName": "MI250",
"configurePreset": "MI250-debug",
"description": "Build Environment for MI250 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
},
{
"name": "RX7800-release",
"displayName": "RX7800",
"configurePreset": "RX7800-release",
"description": "Build Environment for RX7800 Release.",
"inherits": [
"Release"
],
"jobs": 128
},
{
"name": "RX7800-debug",
"displayName": "RX7800",
"configurePreset": "RX7800-debug",
"description": "Build Environment for RX7800 Debug.",
"inherits": [
"Debug"
],
"jobs": 128
}
]
}
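These presets are consumed with the standard CMake preset workflow. A minimal usage sketch, assuming CMake >= 3.21 (needed for a version-3 preset file) and the repository root as the working directory; the preset names come from the file above:

    cmake --preset MI250-release
    cmake --build --preset MI250-release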
...@@ -359,6 +359,21 @@ struct GeneratorTensor_Sequential
}
};
template <ck::index_t Dim>
struct GeneratorTensor_Sequential<ck::e8m0_bexp_t, Dim>
{
int offset = 0;
template <typename... Ts>
ck::e8m0_bexp_t operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
int tmp = dims[Dim];
return ck::type_convert<ck::e8m0_bexp_t>(powf(2, tmp + offset));
}
};
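For orientation, this specialization fills each element with a power of two derived from its coordinate along Dim, converted to the e8m0 biased-exponent type. A minimal host-side sketch of the values it would generate, assuming Dim = 1 and a hypothetical offset of -1 (plain C++, nothing CK-specific):

#include <cmath>
#include <cstdio>

int main()
{
    const int offset = -1; // hypothetical offset
    for(int i = 0; i < 2; ++i)
    {
        // element (i, j) holds 2^(j + offset) before e8m0 conversion
        for(int j = 0; j < 4; ++j)
            std::printf("%g ", std::pow(2.0f, j + offset));
        std::printf("\n"); // each row prints: 0.5 1 2 4
    }
    return 0;
}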
template <typename T, size_t NumEffectiveDim = 2>
struct GeneratorTensor_Diagonal
{
......
...@@ -780,7 +780,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
}
};
- // TODO: fix mfma...f8f6f4 instructions
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
{
...@@ -847,9 +846,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
- __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ __device__ void run(const FloatA& a,
+                     const int32_t& scale_a,
+                     const FloatB& b,
+                     const int32_t& scale_b,
+                     FloatC& reg_c) const
{
- intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+ intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
+     a, scale_a, b, scale_b, reg_c);
}
};
...@@ -871,9 +875,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
- __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ __device__ void run(const FloatA& a,
+                     const int32_t& scale_a,
+                     const FloatB& b,
+                     const int32_t& scale_b,
+                     FloatC& reg_c) const
{
- intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+ intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
+     a, scale_a, b, scale_b, reg_c);
}
};
......
...@@ -519,12 +519,36 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
-                           const int32_t scale_a,
+                           const int32_t& scale_a,
                            const f8x32_t& reg_b,
-                           const int32_t scale_b,
+                           const int32_t& scale_b,
                            FloatC& reg_c)
{
#if defined(__gfx950__)
if(threadIdx.x == 0 || threadIdx.x == 32)
{
printf("thread: %u -- xA: %x\n", threadIdx.x, static_cast<uint32_t>(scale_a));
printf("thread: %u -- xB: %x\n", threadIdx.x, static_cast<uint32_t>(scale_b));
// printf("intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> thread: %u -- scale_a: %f\n",
// threadIdx.x,
// static_cast<float>(ck::e8m0_bexp_t(scale_a)));
// printf("intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> thread: %u -- scale_b: %f\n",
// threadIdx.x,
// static_cast<float>(ck::e8m0_bexp_t(scale_b)));
// for(size_t i = 0; i < 32; i++)
// {
// printf("thread: %u -- reg_a[%zu]: %f\n",
// threadIdx.x,
// i,
// type_convert<float>(f8_t{static_cast<f8x32_t::data_v>(reg_a)[i]}));
// // printf("thread: %u -- reg_a[%zu]: %f\n",
// // threadIdx.x,
// // i,
// // type_convert<float>(f8_t{static_cast<f8x32_t::data_v>(reg_b)[i]}));
// }
}
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float16_t>()(Number<0>{}) =
    __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
......
...@@ -30,11 +30,11 @@ bool run_mfma_test(ck::index_t init)
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
- const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
+ const auto mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
bool pass = true;
- pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
+ pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
                                AType,
                                BType,
                                CType,
...@@ -45,7 +45,7 @@ bool run_mfma_test(ck::index_t init)
                                CLayout,
                                BLOCK_M,
                                BLOCK_N,
-                                BLOCK_K>{}(mx_mfma_kernel, init);
+                                BLOCK_K>{}(mfma_kernel, init);
return pass;
}
...@@ -63,3 +63,98 @@ TEST(MFMA, FP8MFMA32x32x64)
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
/**
* @brief Run the test for the given MX MFMA instruction
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mxmfma_test(ck::index_t init)
{
static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
"Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AccType = float; // only MFMA_F32 instructions supported
// using CPUAccType = AccType;
using ScaleType = ck::e8m0_bexp_t; // biased exponent type
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
constexpr auto BLOCK_X = 32; // scaling vector size
const auto mx_mfma_kernel =
ck::matmul<AType, BType, ScaleType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_X>;
bool pass = true;
pass = ck::mxmfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
AType,
BType,
ScaleType,
CType,
ALayout,
BLayout,
CLayout,
BLOCK_M,
BLOCK_N,
BLOCK_K,
BLOCK_X>{}(mx_mfma_kernel, init);
return pass;
}
TEST(MXMFMA, MXFP8MFMA16x16x128i2)
{
auto AB_init = 2;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64i2)
{
auto AB_init = 2;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA16x16x128i3)
{
auto AB_init = 3;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64i3)
{
auto AB_init = 3;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA16x16x128i4)
{
auto AB_init = 4;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64i4)
{
auto AB_init = 4;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64i5)
{
auto AB_init = 5;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
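Once the test binary is built, the new cases can be run in isolation through the usual googletest filter (the binary name below is only a placeholder):

    ./test_mfma --gtest_filter='MXMFMA.*'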
...@@ -18,7 +18,13 @@ enum class MFMA_F8F6F4
F32_16x16x128 =
    static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
F32_32x32x64 =
-     static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4
+     static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4), // V_MFMA_F32_32X32X64_F8F6F4
SCALE_F32_16x16x128 = static_cast<int>(
MfmaInstr::mfma_scale_f32_16x16x128f8f6f4), // V_MFMA_SCALE_F32_16X16X128_F8F6F4
SCALE_F32_32x32x64 = static_cast<int>(
MfmaInstr::mfma_scale_f32_32x32x64f8f6f4) // V_MFMA_SCALE_F32_32X32X64_F8F6F4
};
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
...@@ -32,6 +38,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
__device__ void operator()(AFragT const& fragA,
const int32_t& scale_a,
BFragT const& fragB,
const int32_t& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
}
};
template <typename AFragT, typename BFragT, typename AccumFragT>
...@@ -42,6 +59,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
__device__ void operator()(AFragT const& fragA,
const int32_t& scale_a,
BFragT const& fragB,
const int32_t& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
}
};
template <typename VecT>
...@@ -131,11 +159,121 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
return fragA;
}
// Define a load function for input A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in row major format
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_row_major(AType const* input_ptr)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Here we want to load a BLOCK_M x BLOCK_K block of data.
static constexpr uint32_t VW = vectorSize(AFragT{});
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW); // Col
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(startCoord2D, BLOCK_K);
auto const* fragPtr = reinterpret_cast<AFragT const*>(input_ptr + startOffset);
return *fragPtr;
}
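As a sanity check of the addressing above, a small host-side sketch (plain C++; the 32x64 shapes come from the right-hand mapping table, the rest is illustrative) reproduces the per-lane start offsets:

#include <cstdio>

int main()
{
    const int BLOCK_M = 32, BLOCK_K = 64, VW = 32; // shapes from the 32x64 mapping table
    for(int tid : {0, 1, 32, 63})
    {
        int row = tid % BLOCK_M;        // M coordinate of this lane
        int col = (tid / BLOCK_M) * VW; // first K element owned by this lane
        std::printf("thread %2d -> row %2d, col %2d, start offset %4d\n",
                    tid, row, col, row * BLOCK_K + col);
    }
    return 0;
}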
// Define a load function for scaled A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in row major format
// - The scale inputs are distributed across 64 lanes.
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
template <typename AType,
typename AFragT,
typename ScaleType,
typename ScaleFragT,
int32_t BLOCK_M,
int32_t BLOCK_K,
int32_t BLOCK_X>
__device__ AFragT load_mx_A_row_major(AType const* input_ptr,
ScaleType const* scale_ptr,
ScaleFragT& fragX)
{
static constexpr uint32_t VW = vectorSize(AFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 1 scale element.
// We need to know where it starts.
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K / BLOCK_X is a stride in xA matrix
auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);
// preserve the upper bits and place the 8-bit biased exponent in the low byte
fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
}
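A minimal sketch of the scale packing performed just above, assuming the scale's 8-bit biased exponent is 127 (the e8m0 encoding of 1.0, as used elsewhere in this change):

#include <cstdint>
#include <cstdio>

int main()
{
    int32_t fragX = 0;
    uint8_t biased_exp = 127;                           // e8m0 encoding of 1.0
    fragX = (fragX & 0xFFFFFF00) | (biased_exp & 0xFF); // exponent lands in byte 0
    std::printf("fragX = 0x%08x\n", fragX);             // prints fragX = 0x0000007f
    return 0;
}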
// Define a load function for input B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
- // - Data is in row_major format
+ // - Data is in column major format
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
...@@ -199,6 +337,46 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
return *fragPtr;
}
// Define a load function for scaled B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in column major format
// - The scale inputs are distributed across 64 lanes.
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType,
typename BFragT,
typename ScaleType,
typename ScaleFragT,
int32_t BLOCK_K,
int32_t BLOCK_N,
int32_t BLOCK_X>
__device__ BFragT load_mx_B_col_major(BType const* input_ptr,
ScaleType const* scale_ptr,
ScaleFragT& fragX)
{
static constexpr uint32_t VW = vectorSize(BFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 1 scale element.
// We need to know where to start
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
threadIdx.x % BLOCK_N); // Col
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
// preserve the upper bits and place the 8-bit biased exponent in the low byte
fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);
return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
}
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
...@@ -309,6 +487,129 @@ struct store_C_col_major<CType, CFragT, 32, 32>
}
};
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in row major format
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct store_C_row_major;
// Here we want to store a 16x16 block of data.
//
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
// Register Element ------------ ------------- ------------ -------------- Element
// Reg0 | M0 | M4 | M8 | M12 | v[0]
// Reg1 | M1 | M5 | M9 | M13 | v[1]
// Reg2 | M2 | M6 | M10 | M14 | v[2]
// Reg3 | M3 | M7 | M11 | M15 | v[3]
template <typename CType, typename CFragT>
struct store_C_row_major<CType, CFragT, 16, 16>
{
__device__ void operator()(CType* output, CFragT cFrag)
{
static constexpr uint32_t VW = vectorSize(cFrag); // 4
static constexpr uint32_t Dim = 16;
// Each thread will store 4 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
auto stepCoord2D = std::make_pair(1u, 0u);
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
auto startOffset = row_major(startCoord2D, 16);
auto kOffset = row_major(stepCoord2D, 16);
// If you notice carefully, kOffset != 1.
// This means the vector below is written out at 4 non-contiguous offsets,
// which the compiler will separate into 4 different global_store_dword instructions.
output[startOffset] = cFrag[0]; // v[0] = Reg 0
output[startOffset + kOffset] = cFrag[1]; // v[1] = Reg 1
output[startOffset + 2 * kOffset] = cFrag[2]; // v[2] = Reg 2
output[startOffset + 3 * kOffset] = cFrag[3]; // v[3] = Reg 3
}
};
// Here we want to store a 32x32 block of data.
// Register Mapping:
// Size | BLOCK_N | BLOCK_N |
// N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- Element
// Reg0 | M0 | M4 | v[0]
// Reg1 | M1 | M5 | v[1]
// Reg2 | M2 | M6 | v[2]
// Reg3 | M3 | M7 | v[3]
// ____________ _____________
// Reg4 | M8 | M12 | v[4]
// Reg5 | M9 | M13 | v[5]
// Reg6 | M10 | M14 | v[6]
// Reg7 | M11 | M15 | v[7]
// ____________ _____________
// Reg8 | M16 | M20 | v[8]
// Reg9 | M17 | M21 | v[9]
// Reg10 | M18 | M22 | v[10]
// Reg11 | M19 | M23 | v[11]
// ____________ _____________
// Reg12 | M24 | M28 | v[12]
// Reg13 | M25 | M29 | v[13]
// Reg14 | M26 | M30 | v[14]
// Reg15 | M27 | M31 | v[15]
template <typename CType, typename CFragT>
struct store_C_row_major<CType, CFragT, 32, 32>
{
__device__ void operator()(CType* output, CFragT cFrag)
{
static constexpr uint32_t WAVE_SIZE = 64;
static constexpr uint32_t VW = 4; // This VW is per 'chunk'
static constexpr uint32_t Dim = 32; // BLOCK_N
static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
// Minor step for each 'chunk'
auto minorStepCoord2D = std::make_pair(1u, 0u);
// Major step between 'chunks'
auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
auto startOffset = row_major(startCoord2D, 32);
auto kMinorOffset = row_major(minorStepCoord2D, 32);
auto kMajorOffset = row_major(majorStepCoord2D, 32);
output[startOffset] = cFrag[0]; // v[0] = Reg 0
output[startOffset + kMinorOffset] = cFrag[1]; // v[1] = Reg 1
output[startOffset + 2 * kMinorOffset] = cFrag[2]; // v[2] = Reg 2
output[startOffset + 3 * kMinorOffset] = cFrag[3]; // v[3] = Reg 3
output[startOffset + kMajorOffset] = cFrag[4]; // v[4] = Reg 4
output[startOffset + kMajorOffset + kMinorOffset] = cFrag[5]; // v[5] = Reg 5
output[startOffset + kMajorOffset + 2 * kMinorOffset] = cFrag[6]; // v[6] = Reg 6
output[startOffset + kMajorOffset + 3 * kMinorOffset] = cFrag[7]; // v[7] = Reg 7
output[startOffset + 2 * kMajorOffset] = cFrag[8]; // v[8] = Reg 8
output[startOffset + 2 * kMajorOffset + kMinorOffset] = cFrag[9]; // v[9] = Reg 9
output[startOffset + 2 * kMajorOffset + 2 * kMinorOffset] = cFrag[10]; // v[10] = Reg 10
output[startOffset + 2 * kMajorOffset + 3 * kMinorOffset] = cFrag[11]; // v[11] = Reg 11
output[startOffset + 3 * kMajorOffset] = cFrag[12]; // v[12] = Reg 12
output[startOffset + 3 * kMajorOffset + kMinorOffset] = cFrag[13]; // v[13] = Reg 13
output[startOffset + 3 * kMajorOffset + 2 * kMinorOffset] = cFrag[14]; // v[14] = Reg 14
output[startOffset + 3 * kMajorOffset + 3 * kMinorOffset] = cFrag[15]; // v[15] = Reg 15
}
};
template <typename AType,
          typename BType,
          typename CType,
...@@ -342,7 +643,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
+ __syncthreads();
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
+ __syncthreads();
for(int i = 0; i < vectorSize(fragC); ++i)
{
...@@ -352,6 +655,139 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
template <typename AType,
typename BType,
typename ScaleType,
typename CType,
typename AccType,
int32_t BLOCK_M,
int32_t BLOCK_N,
int32_t BLOCK_K,
int32_t BLOCK_X>
__global__ void
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
{
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using ScaleFragT = int32_t;
// Create frags
auto fragA = AFragT{};
auto fragB = BFragT{};
auto fragC = CFragT{};
auto fragAcc = AccumFragT{0};
auto fragXa = ScaleFragT{0};
auto fragXb = ScaleFragT{0};
// Load the inputs.
// A = row major, BLOCK_M x BLOCK_K
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
// Scaled Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
__syncthreads();
// printf("thread: %u -- fragXa: %d\n", threadIdx.x, fragXa);
printf("thread: %u -- fragA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
fragA.data_.dN[0],
fragA.data_.dN[1],
fragA.data_.dN[2],
fragA.data_.dN[3],
fragA.data_.dN[4],
fragA.data_.dN[5],
fragA.data_.dN[6],
fragA.data_.dN[7],
fragA.data_.dN[8],
fragA.data_.dN[9],
fragA.data_.dN[10],
fragA.data_.dN[11],
fragA.data_.dN[12],
fragA.data_.dN[13],
fragA.data_.dN[14],
fragA.data_.dN[15],
fragA.data_.dN[16],
fragA.data_.dN[17],
fragA.data_.dN[18],
fragA.data_.dN[19],
fragA.data_.dN[20],
fragA.data_.dN[21],
fragA.data_.dN[22],
fragA.data_.dN[23],
fragA.data_.dN[24],
fragA.data_.dN[25],
fragA.data_.dN[26],
fragA.data_.dN[27],
fragA.data_.dN[28],
fragA.data_.dN[29],
fragA.data_.dN[30],
fragA.data_.dN[31]);
printf("thread: %u -- fragB: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
fragB.data_.dN[0],
fragB.data_.dN[1],
fragB.data_.dN[2],
fragB.data_.dN[3],
fragB.data_.dN[4],
fragB.data_.dN[5],
fragB.data_.dN[6],
fragB.data_.dN[7],
fragB.data_.dN[8],
fragB.data_.dN[9],
fragB.data_.dN[10],
fragB.data_.dN[11],
fragB.data_.dN[12],
fragB.data_.dN[13],
fragB.data_.dN[14],
fragB.data_.dN[15],
fragB.data_.dN[16],
fragB.data_.dN[17],
fragB.data_.dN[18],
fragB.data_.dN[19],
fragB.data_.dN[20],
fragB.data_.dN[21],
fragB.data_.dN[22],
fragB.data_.dN[23],
fragB.data_.dN[24],
fragB.data_.dN[25],
fragB.data_.dN[26],
fragB.data_.dN[27],
fragB.data_.dN[28],
fragB.data_.dN[29],
fragB.data_.dN[30],
fragB.data_.dN[31]);
//__builtin_amdgcn_mfma_ld_scale_b32(fragXa, 0, 0);
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
fragA, fragXa, fragB, fragXb, fragAcc);
__syncthreads();
for(int i = 0; i < vectorSize(fragC); ++i)
{
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
__syncthreads();
auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
/**
 * @brief Structure to hold dimension parameters for GEMM tensors.
 *
...@@ -373,6 +809,384 @@ struct GemmParams
ck::index_t StrideC = -1;
};
namespace mxmfma_test {
template <typename ADataType, typename BDataType, typename ScaleType, typename CDataType>
void RunHostGEMM(const Tensor<ADataType>& A,
const Tensor<ScaleType>& a_scales,
const Tensor<BDataType>& B,
const Tensor<ScaleType>& b_scales,
Tensor<CDataType>& C)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using GemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
float,
CDataType,
float,
PassThrough,
PassThrough,
PassThrough,
float,
float>;
Tensor<float> a_m_k(A.mDesc);
Tensor<float> b_k_n(B.mDesc);
const auto M = A.mDesc.GetLengths()[0];
const auto N = B.mDesc.GetLengths()[1];
const auto K = A.mDesc.GetLengths()[1];
const auto BLOCK_X = K / a_scales.mDesc.GetLengths()[1];
for(size_t m = 0; m < M; m++)
{
for(size_t k = 0; k < K; k++)
{
a_m_k(m, k) =
type_convert<float>(A(m, k)) * type_convert<float>(a_scales(m, k / BLOCK_X));
}
}
for(size_t n = 0; n < N; n++)
{
for(size_t k = 0; k < K; k++)
{
b_k_n(k, n) =
type_convert<float>(B(k, n)) * type_convert<float>(b_scales(k / BLOCK_X, n));
}
}
auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, C, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
}
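The host reference above widens A and B to float and multiplies every element by the scale of its K-block before running a plain float GEMM. A minimal standalone sketch of that block-scale dequantization (hypothetical sizes; scales are treated as plain floats here):

#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, K = 64, BLOCK_X = 32;
    std::vector<float> A(M * K, 1.0f);                    // quantized values, widened to float
    std::vector<float> a_scales(M * (K / BLOCK_X), 0.5f); // one scale per 32-wide K block
    std::vector<float> a_dequant(M * K);
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            a_dequant[m * K + k] = A[m * K + k] * a_scales[m * (K / BLOCK_X) + k / BLOCK_X];
    std::printf("a_dequant[0] = %g\n", a_dequant[0]); // prints a_dequant[0] = 0.5
    return 0;
}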
template <typename KernelType,
typename ADataType,
typename BDataType,
typename ScaleType,
typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
const Tensor<ADataType>& A,
const Tensor<ScaleType>& a_scales,
const Tensor<BDataType>& B,
const Tensor<ScaleType>& b_scales,
Tensor<CDataType>& C)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize());
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(A.mData.data());
a_scales_device_buf.ToDevice(a_scales.mData.data());
b_n_k_device_buf.ToDevice(B.mData.data());
b_scales_device_buf.ToDevice(b_scales.mData.data());
kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<const ScaleType*>(a_scales_device_buf.GetDeviceBuffer()),
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<const ScaleType*>(b_scales_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
c_m_n_device_buf.FromDevice(C.mData.data());
return true;
}
template <typename DeviceMFMA,
typename ADataType,
typename BDataType,
typename ScaleType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
index_t BLOCK_M,
index_t BLOCK_N,
index_t BLOCK_K,
index_t BLOCK_X>
struct TestMXMFMA
{
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<ScaleType> a_scales(
f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
Tensor<BDataType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<ScaleType> b_scales(
f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}});
// NOTE: not all numbers are representable in FP8, BF8, etc.
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
break;
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{127, 129}); // 1, 2 // scales: {0.5, 1, 2}
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
case 3:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{128, 129}); // 2
// a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
case 4:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.3});
a_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 128}); // 1, 2
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
case 5:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.0});
for(size_t i = 0; i < 32; i++)
{
a_m_k(0, i) = type_convert<ADataType>(1.0f);
}
for(size_t i = 32; i < 64; i++)
{
a_m_k(0, i) = type_convert<ADataType>(-2.0f);
}
// printf("f8 1: %x \n", type_convert<ADataType>(1.0f).data);
// printf("f8 -2: %x \n", type_convert<ADataType>(-2.0f).data);
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
a_scales(0, 0) = ScaleType{1.0f};
a_scales(0, 1) = ScaleType{0.5f};
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{0.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
for(size_t i = 0; i < 64; i++)
{
b_n_k(i, 0) = type_convert<BDataType>(1.0f);
}
break;
// case 3:
// // expect small round off errors
// a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
// a_scales.GenerateTensorValue(
// GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
// b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
// b_scales.GenerateTensorValue(
// GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
// break;
// case 4:
// a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
// a_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 0>{-9});
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
// b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
// break;
// case 5:
// a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
// a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
// b_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 1>{-9});
// break;
case 6:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.00195312f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 16}});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 1>{-9});
break;
default:
// all initial values are representable in FP8, BF8
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
// a_scales.GenerateTensorValue(GeneratorTensor_3<ScaleType>{1.0f / 32.0f, 1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
// b_scales.GenerateTensorValue(GeneratorTensor_3<ScaleType>{1.0f / 32.0f, 1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
}
return std::make_tuple(
a_m_k, a_scales, b_n_k, b_scales, c_m_n_host_result, c_m_n_device_result);
}
auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
// Arrange
GemmParams params;
params.M = BLOCK_M;
params.N = BLOCK_N;
params.K = BLOCK_K;
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
ck::index_t stride,
auto layout) {
if(stride == -1)
{
// if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
auto host_tensors = PrepareGemmTensors(params, init);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
const Tensor<BDataType>& b = std::get<2>(host_tensors);
const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
Tensor<CDataType>& c_host = std::get<4>(host_tensors);
Tensor<CDataType>& c_device = std::get<5>(host_tensors);
RunHostGEMM(a, a_scales, b, b_scales, c_host);
RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device);
#if 0
#if 1
std::cout << "a:" << std::endl;
for(size_t i = 0; i < BLOCK_M; i++)
{
for(size_t j = 0; j < BLOCK_K; j++)
{
std::cout << type_convert<float>(a(i, j)) << " ";
}
std::cout << std::endl;
break;
}
// std::cout << "b:" << std::endl;
// for(size_t i = 0; i < BLOCK_K; i++)
// {
// for(size_t j = 0; j < BLOCK_N; j++)
// {
// if(j == 0)
// std::cout << type_convert<float>(b(i, j)) << " ";
// }
// std::cout << std::endl;
// }
#endif
#if 0
std::cout << "a_scale:" << std::endl;
for(size_t i = 0; i < BLOCK_M; i++)
{
for(size_t j = 0; j < BLOCK_K / BLOCK_X; j++)
{
std::cout << type_convert<float>(a_scales(i, j)) << " ";
}
std::cout << std::endl;
}
// std::cout << "b_scale:" << std::endl;
// for(size_t i = 0; i < BLOCK_K / BLOCK_X; i++)
// {
// for(size_t j = 0; j < BLOCK_N; j++)
// {
// std::cout << type_convert<float>(b_scales(i, j)) << " ";
// }
// std::cout << std::endl;
// }
#endif
std::cout << "c_device:" << std::endl;
for(size_t i = 0; i < BLOCK_M; i++)
{
for(size_t j = 0; j < BLOCK_N; j++)
{
std::cout << type_convert<float>(c_device(i, j)) << " ";
}
std::cout << std::endl;
break;
}
#endif
bool res = false;
if constexpr(std::is_same<CDataType, float>::value ||
std::is_same<CDataType, half_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
if(!res)
{
std::cout << "c_host:" << std::endl;
for(size_t i = 0; i < BLOCK_M; i++)
{
for(size_t j = 0; j < BLOCK_N; j++)
{
std::cout << type_convert<float>(c_host(i, j)) << " ";
}
std::cout << std::endl;
break;
}
}
}
else
{
std::cout << "UNSUPPORTED CDataType" << std::endl;
}
return res;
}
};
} // namespace mxmfma_test
namespace mfma_test {
template <typename GemmInstance,
          typename ADataType,
......
#include <hip/hip_ext.h>
#include <hip/hip_runtime.h>
__global__ void kernel()
{
using dataAB = uint8_t __attribute__((ext_vector_type(32)));
using dataC = float __attribute__((ext_vector_type(16)));
using dataX = int32_t __attribute__((ext_vector_type(2)));
dataAB regA(0x38);
dataAB regB(0x38);
dataC regC(1.0f);
// dataC regCin(1.0f);
#if 1
// dataX xa{127, 127}; // 1.0
dataX xa(127 & 0xFF); // 1.0
dataX xb(127 & 0xFF); // 1.0
#else
dataX xa(0);
dataX xb(0);
#endif
#if 0
if(threadIdx.x == 0)
{
// xa = 127; // 1.0
for(size_t i = 0; i < 32; i++)
{
regA[i] = 0x38; // 1.0
}
for(size_t i = 0; i < 32; i++)
{
regB[i] = 0x38; // 1.0
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
if(threadIdx.x == 32)
{
// xa = 126; // 0.5
for(size_t i = 0; i < 32; i++)
{
regA[i] = 0xC0; // -2.0
}
for(size_t i = 0; i < 32; i++)
{
regB[i] = 0x38; // 1.0
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
#endif
__syncthreads();
printf("thread: %u -- regA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
regA[0],
regA[1],
regA[2],
regA[3],
regA[4],
regA[5],
regA[6],
regA[7],
regA[8],
regA[9],
regA[10],
regA[11],
regA[12],
regA[13],
regA[14],
regA[15],
regA[16],
regA[17],
regA[18],
regA[19],
regA[20],
regA[21],
regA[22],
regA[23],
regA[24],
regA[25],
regA[26],
regA[27],
regA[28],
regA[29],
regA[30],
regA[31]);
printf("thread: %u -- regB: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
regB[0],
regB[1],
regB[2],
regB[3],
regB[4],
regB[5],
regB[6],
regB[7],
regB[8],
regB[9],
regB[10],
regB[11],
regB[12],
regB[13],
regB[14],
regB[15],
regB[16],
regB[17],
regB[18],
regB[19],
regB[20],
regB[21],
regB[22],
regB[23],
regB[24],
regB[25],
regB[26],
regB[27],
regB[28],
regB[29],
regB[30],
regB[31]);
//__builtin_amdgcn_mfma_ld_scale_b32(xb[threadIdx.x / 32], 0, 0);
regC = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(regA,
regB,
regC,
0, // cbsz
0, // blgp
0,
xa[threadIdx.x / 32],
0,
xb[threadIdx.x / 32]);
__syncthreads();
printf("thread: %u -- regC: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n",
threadIdx.x,
regC[0],
regC[1],
regC[2],
regC[3],
regC[4],
regC[5],
regC[6],
regC[7],
regC[8],
regC[9],
regC[10],
regC[11],
regC[12],
regC[13],
regC[14],
regC[15]);
// printf("thread: %u -- regCin: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n",
// threadIdx.x,
// regCin[0],
// regCin[1],
// regCin[2],
// regCin[3],
// regCin[4],
// regCin[5],
// regCin[6],
// regCin[7],
// regCin[8],
// regCin[9],
// regCin[10],
// regCin[11],
// regCin[12],
// regCin[13],
// regCin[14],
// regCin[15]);
}
int main()
{
kernel<<<1, 64>>>();
return 0;
}
\ No newline at end of file
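For reference, a minimal way to build and run this standalone check on a gfx950 machine might look like the following (the source file name is a placeholder; the scaled MFMA builtin requires the gfx950 target):

    hipcc --offload-arch=gfx950 -o mfma_scale_check mfma_scale_check.cpp
    ./mfma_scale_check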