Commit a24c1b01 authored by Andriy Roshchenko

WIP: Completed implementation for MX FP8 MFMA

parent 465ba138
@@ -847,9 +847,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
- __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ __device__ void run(const FloatA& a,
+ const int32_t& scale_a,
+ const FloatB& b,
+ const int32_t& scale_b,
+ FloatC& reg_c) const
{
- intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+ intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
+ a, scale_a, b, scale_b, reg_c);
}
};
@@ -871,9 +876,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
- __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ __device__ void run(const FloatA& a,
+ const int32_t& scale_a,
+ const FloatB& b,
+ const int32_t& scale_b,
+ FloatC& reg_c) const
{
- intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+ intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
+ a, scale_a, b, scale_b, reg_c);
}
};
...
@@ -79,9 +79,9 @@ bool run_mxmfma_test(ck::index_t init)
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AccType = float; // only MFMA_F32 instructions supported
- using CPUAccType = AccType;
+ // using CPUAccType = AccType;
using ScaleType = ck::e8m0_bexp_t; // biased exponent type
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
...
@@ -38,6 +38,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
__device__ void operator()(AFragT const& fragA,
const int32_t& scale_a,
BFragT const& fragB,
const int32_t& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
}
};
template <typename AFragT, typename BFragT, typename AccumFragT>
@@ -48,6 +59,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
__device__ void operator()(AFragT const& fragA,
const int32_t& scale_a,
BFragT const& fragB,
const int32_t& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
}
};
template <typename VecT>
@@ -137,11 +159,121 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
return fragA;
}
// Define a load function for input A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in row major format
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_row_major(AType const* input_ptr)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Here we want to load a BLOCK_M x BLOCK_K block of data.
static constexpr uint32_t VW = vectorSize(AFragT{});
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW); // Col
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K is the stride (leading dimension) of the A matrix
auto startOffset = row_major(startCoord2D, BLOCK_K);
auto const* fragPtr = reinterpret_cast<AFragT const*>(input_ptr + startOffset);
return *fragPtr;
}
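For intuition, here is a minimal host-side sketch (illustrative only, not part of this patch; the names a_start_offset and lane are invented) that reproduces the lane-to-offset mapping above and checks a few lanes against the register-mapping table:

// Illustrative only: recompute load_A_row_major's starting offset for a lane.
#include <cassert>
#include <cstdint>

constexpr int32_t a_start_offset(int32_t lane, int32_t BLOCK_M, int32_t BLOCK_K, int32_t VW)
{
    const int32_t row = lane % BLOCK_M;        // M index
    const int32_t col = (lane / BLOCK_M) * VW; // first K index held by this lane
    return row * BLOCK_K + col;                // row-major, leading dimension BLOCK_K
}

int main()
{
    // 16x128 block of 8-bit values: VW = 16 * 128 / 64 = 32 elements per lane.
    assert(a_start_offset(0, 16, 128, 32) == 0);              // lane 0  -> row 0,  K0
    assert(a_start_offset(16, 16, 128, 32) == 32);            // lane 16 -> row 0,  K32
    assert(a_start_offset(63, 16, 128, 32) == 15 * 128 + 96); // lane 63 -> row 15, K96
    // 32x64 block: VW = 32 * 64 / 64 = 32.
    assert(a_start_offset(32, 32, 64, 32) == 32);             // lane 32 -> row 0,  K32
    return 0;
}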
// Define a load function for scaled A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in row major format
// - The scale inputs are distributed across the 64 lanes of the wavefront.
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
template <typename AType,
typename AFragT,
typename ScaleType,
typename ScaleFragT,
int32_t BLOCK_M,
int32_t BLOCK_K,
int32_t BLOCK_X>
__device__ AFragT load_mx_A_row_major(AType const* input_ptr,
ScaleType const* scale_ptr,
ScaleFragT& fragX)
{
static constexpr uint32_t VW = vectorSize(AFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
// To start the loading process, let's visualize in 2D coords.
// Each thread will load one scale element.
// We need to know where it starts.
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K / BLOCK_X is the stride (leading dimension) of the A scale matrix
auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);
// Preserve the upper bits and place the 8-bit biased exponent in the low byte.
fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
}
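A small sketch of the scale handling above (again illustrative, not from the patch; pack_scale_byte and xa_offset are made-up names): the 8-bit biased exponent is dropped into the low byte of the 32-bit scale operand, and each lane reads exactly one scale because its VW == BLOCK_X data elements span a single MX block:

// Illustrative only: scale-byte packing and per-lane scale offset for the A side.
#include <cassert>
#include <cstdint>

// Keep bits [31:8] of the scale register, put the biased exponent into bits [7:0].
constexpr int32_t pack_scale_byte(int32_t scale_reg, uint8_t biased_exp)
{
    return (scale_reg & 0xFFFFFF00) | static_cast<int32_t>(biased_exp);
}

// Row-major offset into the (BLOCK_M x BLOCK_K / BLOCK_X) scale matrix for one lane.
constexpr int32_t xa_offset(int32_t lane, int32_t BLOCK_M, int32_t BLOCK_K, int32_t VW, int32_t BLOCK_X)
{
    const int32_t row = lane % BLOCK_M;
    const int32_t col = (lane / BLOCK_M) * VW / BLOCK_X; // one scale per MX block of BLOCK_X values
    return row * (BLOCK_K / BLOCK_X) + col;
}

int main()
{
    assert(pack_scale_byte(0, 127) == 127);                   // exponent 127, i.e. scale 2^0
    assert(pack_scale_byte(0x12345600, 0x7F) == 0x1234567F);  // upper bytes preserved
    // 16x128 block, BLOCK_X = 32, VW = 32: lane 17 reads the scale at (row 1, col 1).
    assert(xa_offset(17, 16, 128, 32, 32) == 1 * (128 / 32) + 1);
    return 0;
}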
// Define a load function for input B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
- // - Data is in row_major format
+ // - Data is in column major format
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
@@ -205,6 +337,46 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
return *fragPtr;
}
// Define a load function for scaled B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in column major format
// - The scale inputs are distributed across the 64 lanes of the wavefront.
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType,
typename BFragT,
typename ScaleType,
typename ScaleFragT,
int32_t BLOCK_K,
int32_t BLOCK_N,
int32_t BLOCK_X>
__device__ BFragT load_mx_B_col_major(BType const* input_ptr,
ScaleType const* scale_ptr,
ScaleFragT& fragX)
{
static constexpr uint32_t VW = vectorSize(BFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
// To start the loading process, let's visualize in 2D coords.
// Each thread will load one scale element.
// We need to know where to start
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
threadIdx.x % BLOCK_N); // Col
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
// Preserve the upper bits and place the 8-bit biased exponent in the low byte.
fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);
return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
}
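For reference, the arithmetic the scaled path represents, sketched on the host (illustrative only; decode_e8m0 and scaled_dot are invented names, and the usual e8m0 bias of 127 is assumed, matching the "biased exponent type" comment in the test): each BLOCK_X-wide K block of A and B carries one e8m0 scale, and the product is accumulated as (2^(xa-127) * a) · (2^(xb-127) * b):

// Illustrative only: block-scaled dot product with e8m0 scales (bias 127 assumed).
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

inline float decode_e8m0(uint8_t biased_exp)
{
    return std::ldexp(1.0f, static_cast<int>(biased_exp) - 127); // 2^(e - 127)
}

// One K block: acc += (xa * a[k]) * (xb * b[k]), accumulated in float.
inline float scaled_dot(const std::vector<float>& a, uint8_t xa,
                        const std::vector<float>& b, uint8_t xb)
{
    const float sa = decode_e8m0(xa);
    const float sb = decode_e8m0(xb);
    float acc = 0.0f;
    for(std::size_t k = 0; k < a.size(); ++k)
        acc += (sa * a[k]) * (sb * b[k]);
    return acc;
}

int main()
{
    // 2^(129-127) = 4 and 2^(126-127) = 0.5, so the combined scale is 2.
    const std::vector<float> a{1.0f, 2.0f}, b{3.0f, 4.0f};
    return scaled_dot(a, 129, b, 126) == 2.0f * (1.0f * 3.0f + 2.0f * 4.0f) ? 0 : 1;
}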
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
@@ -368,12 +540,49 @@ template <typename AType,
int32_t BLOCK_N,
int32_t BLOCK_K,
int32_t BLOCK_X>
- __global__ void matmul(const AType* a, const BType* b, const ScaleType* x, CType* c)
+ __global__ void
+ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
{
- ignore = a;
- ignore = b;
- ignore = x;
- ignore = c;
+ constexpr int WAVE_SIZE = 64;
+ assert(threadIdx.x < WAVE_SIZE);
+ assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using ScaleFragT = int32_t;
// Create frags
auto fragA = AFragT{};
auto fragB = BFragT{};
auto fragC = CFragT{};
auto fragAcc = AccumFragT{0};
auto fragXa = ScaleFragT{0};
auto fragXb = ScaleFragT{0};
// Load the inputs.
// A = row major, BLOCK_M x BLOCK_K
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
// Scaled Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
fragA, fragXa, fragB, fragXb, fragAcc);
for(int i = 0; i < vectorSize(fragC); ++i)
{
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
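A compile-time sanity sketch (illustrative only, not part of the commit; elems_per_lane is an invented helper) of the per-lane fragment sizes implied by the vector_type definitions in the kernel, for the two tile shapes exercised here:

// Illustrative only: per-lane fragment element counts for a 64-lane wavefront.
#include <cstdint>

constexpr int32_t WAVE_SIZE = 64;
constexpr int32_t elems_per_lane(int32_t rows, int32_t cols) { return rows * cols / WAVE_SIZE; }

// 16x16x128: 32 A elements, 32 B elements, 4 float accumulators per lane.
static_assert(elems_per_lane(16, 128) == 32, "A/B fragment, 16x16x128");
static_assert(elems_per_lane(16, 16) == 4, "C/accumulator fragment, 16x16x128");

// 32x32x64: 32 A elements, 32 B elements, 16 float accumulators per lane.
static_assert(elems_per_lane(32, 64) == 32, "A/B fragment, 32x32x64");
static_assert(elems_per_lane(32, 32) == 16, "C/accumulator fragment, 32x32x64");

int main() { return 0; }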
/**
@@ -443,20 +652,32 @@ void RunHostGEMM(const Tensor<ADataType>& A,
ref_invoker.Run(ref_argument);
}
- template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
+ template <typename KernelType,
+ typename ADataType,
+ typename BDataType,
+ typename ScaleType,
+ typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
const Tensor<ADataType>& A,
+ const Tensor<ScaleType>& a_scales,
const Tensor<BDataType>& B,
+ const Tensor<ScaleType>& b_scales,
Tensor<CDataType>& C)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
+ DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize());
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
+ DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(A.mData.data());
+ a_scales_device_buf.ToDevice(a_scales.mData.data());
b_n_k_device_buf.ToDevice(B.mData.data());
+ b_scales_device_buf.ToDevice(b_scales.mData.data());
kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+ static_cast<const ScaleType*>(a_scales_device_buf.GetDeviceBuffer()),
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
+ static_cast<const ScaleType*>(b_scales_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
c_m_n_device_buf.FromDevice(C.mData.data());
@@ -600,7 +821,7 @@ struct TestMXMFMA
RunHostGEMM(a, a_scales, b, b_scales, c_host);
- RunDeviceGEMM(mfma_kernel, a, b, c_device);
+ RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device);
bool res = false;
if constexpr(std::is_same<CDataType, float>::value ||
...