Commit 5872e647 authored by Lei Wang, committed by LeiWang1999

[AMD] Support float8 matrix core (#537)



* [Enhancement] Add support for FP8 types in CUDA and HIP code generation

* Updated the `GetFP8Type` function in `codegen_cuda.cc` and `codegen_hip.cc` to handle new FP8 types, including `kFloat8_e4m3fnuz`.
* Introduced a new header file, `hip_fp8.h`, with FP8 type definitions for HIP.
* Updated the type mappings in `dlpack.py` and `mfma_macro_generator.py` to cover the new FP8 types (a sketch of the resulting dtype-to-MFMA mapping follows the diffs at the end of this page).
* Extended type handling in `TLHIPSourceWrapper` and `tensor.py` so FP8 tensors map cleanly between PyTorch and the HIP backend.
* Added the includes and codegen logic needed to emit FP8 kernels, so generated HIP code compiles and runs against FP8 data.

* lint fix

* Update src/target/codegen_hip.cc
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update tilelang/intrinsics/mfma_macro_generator.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* workaround

* fix

* Update submodule TVM to latest commit 587028ffebfff0ded520f8f90d62f0f6b165906c

* bug fix

* Refactor the tilelang matrix multiplication test to support transposition (`a_transposed`/`b_transposed`) and k-packing (`k_pack`) options, adjust the shared-memory shapes and A/B loading logic accordingly, and update the test cases to cover the new paths (a usage sketch follows the commit metadata below).

* Refactor assertion function for tilelang matrix multiplication to improve readability by formatting parameters and aligning code. Cleaned up whitespace in intrinsic layout functions for consistency.

* Update bfloat16 type definitions in common.h and gemm.h for consistency. Changed __hip_bfloat16 to hip_bfloat16 and updated MfmaTraits specialization accordingly.

* lint fix

---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 1940b3c9
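A quick usage sketch before the diff: the test entry point reworked in this PR takes the new transposition and k-packing knobs as keyword arguments. The module name below is assumed (the test file's path is not shown on this page), and only the first two calls match cases present in the test file; the last call only illustrates the new options.

    # Hypothetical driver; the module name `test_tilelang_gemm_mfma` is assumed,
    # not taken from the diff.
    from test_tilelang_gemm_mfma import assert_tl_matmul_correctness

    # Cases exercised by the test file below: default layout (A row-major, B
    # transposed) and the new k_pack=2 variant, where each vectorized fragment
    # load covers two MFMA k-slices.
    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)

    # Illustration of the new layout knobs (not a committed test case):
    # A stored as (K, M), B stored as (K, N).
    assert_tl_matmul_correctness(256, 256, 256, "float16", "float32",
                                 a_transposed=True, b_transposed=False)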
@@ -58,7 +58,7 @@ def flashmla_decode(batch,
             T.fill(scores_max, -T.infinity(accum_dtype))
             loop_range = T.ceildiv(seqlen_kv, block_N)
-            for k in T.Pipelined(loop_range, num_stages=2):
+            for k in T.Pipelined(loop_range, num_stages=0):
                 T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
                 T.clear(acc_s)
@@ -41,6 +41,8 @@ static std::string GetFP8Type(DataType type) {
   }
   if (type.code() == DataType::kFloat8_e4m3fn) {
     stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kFloat8_e4m3fnuz) {
+    stream << "fp8_e4" << vec << "_t";
   } else if (type.code() == DataType::kFloat8_e5m2) {
     stream << "fp8_e5" << vec << "_t";
   } else {
@@ -20,6 +20,36 @@
 namespace tvm {
 namespace codegen {

+static std::string GetFP8Type(DataType type) {
+  std::stringstream stream;
+  int32_t lanes = type.lanes();
+  std::string vec;
+  if (type.is_scalar()) {
+    vec = "";
+  } else if (lanes == 2) {
+    vec = "_2";
+  } else if (lanes == 4) {
+    vec = "_4";
+  } else if (lanes == 8) {
+    vec = "_8";
+  } else if (lanes == 16) {
+    vec = "_16";
+  } else {
+    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
+                  "for FP8";
+  }
+  if (type.code() == DataType::kFloat8_e4m3fn) {
+    stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kFloat8_e4m3fnuz) {
+    stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kFloat8_e5m2) {
+    stream << "fp8_e5" << vec << "_t";
+  } else {
+    LOG(FATAL) << "Unsupported FP8 type in HIP codegen";
+  }
+  return stream.str();
+}
+
 /*!
  * \brief Replace patterns with replacement strings.
  * \note should use std::format instead when codebase is ported to C++20.

@@ -104,6 +134,11 @@ std::string CodeGenTileLangHIP::Finish() {
   if (need_mma_h_) {
     decl_stream << "#include <mma.h>\n";
   }
+  if (enable_fp8_) {
+    decl_stream << "#include <tl_templates/hip/hip_fp8.h>\n";
+  }
   decl_stream << "#include <tl_templates/hip/gemm.h>\n";
   decl_stream << "#include <tl_templates/hip/copy.h>\n";
   decl_stream << "#include <tl_templates/hip/reduce.h>\n";

@@ -226,18 +261,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
     if (!fail)
       return;
   } else if (t.is_float8()) {
-    if (t.is_scalar()) {
-      os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
-    } else if (lanes == 2) {
-      os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
-                                  // unsigned short
-    } else if (lanes == 4) {
-      os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
-    } else {
-      fail = true;
-    }
-    if (!fail)
-      return;
+    enable_fp8_ = true;
+    os << GetFP8Type(t);
+    return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;

@@ -898,6 +924,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
       {"float16x4", "float16x4"},
       {"bfloat16x4", "bfloat16x4"},
       {"float32x4", "float32x4"},
+      {"float8_e4m3fnuzx4", "fp8_e4_4_t"},
+      {"float8_e4m3fnuzx8", "long"},
       {"float32x16", "float32x16"}};
   std::string call_mfma_code = R"({
     *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),

@@ -906,6 +934,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
   })";
   std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
   Replacer replacer;
   replacer.register_rule("{mfma_buildin}", mfma_buildin);
   replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
   replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
@@ -75,6 +75,8 @@ private:
   bool need_math_constants_h_{false};
   // whether need mfma.h
   bool need_wmma_h_{false};
+  // whether need fp8.h
+  bool enable_fp8_{false};
   // The size of the barrier array in shared memory
   int barrier_count_ = -1;
   // whether need mma.h
@@ -64,7 +64,7 @@ using float16x16 =
 using half_t = float16_t;

-using bfloat16_t = __hip_bfloat16;
+using bfloat16_t = hip_bfloat16;

 struct bfloat16x2 {
   bfloat16_t data[2];
@@ -17,16 +17,16 @@ template <> struct MfmaTraits<half> {
   }
 };

-// Specialization for __hip_bfloat16
-template <> struct MfmaTraits<__hip_bfloat16> {
+// Specialization for bfloat16_t
+template <> struct MfmaTraits<bfloat16_t> {
   template <typename AccType>
-  static TL_DEVICE void mfma_op(const __hip_bfloat16 *b,
-                                const __hip_bfloat16 *a, AccType *c) {
+  static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a,
+                                AccType *c) {
     bfloat16x4_vec b_vec, a_vec;
     // Reinterpret the pointers
-    short *b_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(b));
-    short *a_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(a));
+    short *b_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(b));
+    short *a_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(a));
     // Copy the data
     for (int i = 0; i < 4; ++i) {
New file (included as <tl_templates/hip/hip_fp8.h>):

#include <hip/amd_detail/amd_hip_fp8.h>

using fp8_e4_t = __hip_fp8_e4m3_fnuz;
using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;
using fp8_e4_4_t = __hip_fp8x4_e4m3_fnuz;

struct __align__(8) fp8_e4_8_t {
  fp8_e4_4_t x;
  fp8_e4_4_t y;
};

struct __align__(16) fp8_e4_16_t {
  fp8_e4_8_t x;
  fp8_e4_8_t y;
};
@@ -19,6 +19,9 @@ def tl_matmul(
     in_dtype,
     out_dtype,
     accum_dtype,
+    a_transposed=False,
+    b_transposed=True,
+    k_pack=1,
 ):
     assert in_dtype in [
         "float16",

@@ -32,27 +35,26 @@ def tl_matmul(
     micro_size_x = micro_size_y = micro_size_k = 16

-    if out_dtype == "int32":
+    if in_dtype in {"float8_e4m3fnuz", "int8"}:
         micro_size_k = 32

-    block_row_warps = 1
-    block_col_warps = 1
-    warp_row_tiles = 16
-    warp_col_tiles = 16
+    block_row_warps = 2
+    block_col_warps = 2
+    warp_row_tiles = 32
+    warp_col_tiles = 32
     chunk = 32
-    shared_scope = "shared.dyn"
+    shared_scope = "shared"
     cache_write_shared = False

     block_M = block_row_warps * warp_row_tiles
     block_N = block_col_warps * warp_col_tiles
     block_K = chunk

-    A_shape = (M, K)
-    B_shape = (N, K)
-    A_shared_shape = (block_M, block_K)
-    B_shared_shape = (block_N, block_K)
+    A_shape = (K, M) if a_transposed else (M, K)
+    B_shape = (N, K) if b_transposed else (K, N)
+    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
     C_shared_shape = (
         block_M // micro_size_x,
         block_N // micro_size_y,
         micro_size_x,
         micro_size_y,

@@ -60,8 +62,8 @@ def tl_matmul(
     warp_size = 64
     threads = warp_size * (block_row_warps * block_col_warps)
-    local_size_a = (micro_size_x * micro_size_k) // warp_size
-    local_size_b = (micro_size_y * micro_size_k) // warp_size
+    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
+    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
     local_size_c = (micro_size_x * micro_size_y) // warp_size
     warp_rows = warp_row_tiles // micro_size_x
     warp_cols = warp_col_tiles // micro_size_y

@@ -71,13 +73,14 @@ def tl_matmul(
         a_dtype=in_dtype,
         b_dtype=in_dtype,
         accum_dtype=accum_dtype,
-        a_transposed=False,
-        b_transposed=True,
+        a_transposed=a_transposed,
+        b_transposed=b_transposed,
         block_row_warps=block_row_warps,
         block_col_warps=block_col_warps,
         warp_row_tiles=warp_row_tiles,
         warp_col_tiles=warp_col_tiles,
         chunk=chunk,
+        k_pack=k_pack,
     )

     @T.prim_func

@@ -108,14 +111,18 @@ def tl_matmul(
             for ko in T.Pipelined((K // block_K), num_stages=0):

                 # Load A into shared memory
-                for i, k in T.Parallel(block_M, block_K):
-                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
+                if a_transposed:
+                    T.copy(A[ko * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, ko * block_K], A_shared)

                 # Load B into shared memory
-                for j, k in T.Parallel(block_N, block_K):
-                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
+                if b_transposed:
+                    T.copy(B[bx * block_N, ko * block_K], B_shared)
+                else:
+                    T.copy(B[ko * block_K, bx * block_N], B_shared)

-                for ki in T.serial(0, (block_K // micro_size_k)):
+                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):

                     # Load A into fragment
                     mfma_emitter.ldmatrix_a(
@@ -160,20 +167,30 @@ def tl_matmul(
     return main


-def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"):
-    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
+def assert_tl_matmul_correctness(M,
+                                 N,
+                                 K,
+                                 in_dtype,
+                                 out_dtype,
+                                 accum_dtype="float32",
+                                 a_transposed=False,
+                                 b_transposed=True,
+                                 k_pack=1):
+    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
+                       k_pack)
+    print(matmul)
     kernel = tilelang.compile(matmul)
     src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
+    A_shape = (K, M) if a_transposed else (M, K)
+    B_shape = (N, K) if b_transposed else (K, N)
     if in_dtype == "int8":
-        A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
-        B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
+        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
+        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
     else:
-        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

     kernel(A, B, C)
@@ -185,8 +202,22 @@ def assert_tl_matmul_correctness(M,
     # Ensure that the latency is not None
     assert latency is not None

-    # Get Reference Result
-    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+    if a_transposed and b_transposed:
+        # Get Reference Result
+        ref_c = torch.matmul(A.T.to(torch.float32),
+                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+    elif a_transposed and not b_transposed:
+        # Get Reference Result
+        ref_c = torch.matmul(A.T.to(torch.float32),
+                             B.to(torch.float32)).to(getattr(torch, out_dtype))
+    elif not a_transposed and b_transposed:
+        # Get Reference Result
+        ref_c = torch.matmul(A.to(torch.float32),
+                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+    else:
+        # Get Reference Result
+        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
     print(C)
     print(ref_c)
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@@ -196,6 +227,7 @@ def assert_tl_matmul_correctness(M,
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
+    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)


 if __name__ == "__main__":
@@ -38,7 +38,7 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
     float8_dtype_map = {
         torch.float8_e4m3fn: "e4m3_float8",
-        torch.float8_e4m3fnuz: "e4m3_float8",
+        torch.float8_e4m3fnuz: "float8_e4m3fnuz",
         torch.float8_e5m2: "e5m2_float8",
         torch.float8_e5m2fnuz: "e5m2_float8",
     }
@@ -61,6 +61,8 @@ def shared_16x16_to_local_64x4_layout_B(i, j):
     return thread_id, local


+shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
+shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
+shared_16x16_to_local_64x4_layout_n_m = shared_16x16_to_local_64x4_layout_B
+shared_16x16_to_local_64x4_layout_k_n = shared_16x16_to_local_64x4_layout_B
+
 def thread_id_shared_access_64x4_to_16x16_layout_C_m_n(thread_id, local_id):
     i = local_id + (thread_id // 16) * 4
     j = thread_id % 16

@@ -97,6 +103,30 @@ def shared_16x32_to_local_64x8_layout_B(i, j):
     return thread_id, local


+def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
+    i = thread_id % 16
+    j = local_id + (thread_id // 16) * 16
+    return i, j
+
+
+def shared_16x64_to_local_64x16_layout_A(i, j):
+    thread_id = i + 16 * (j // 16)
+    local = (j % 16)
+    return thread_id, local
+
+
+def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
+    i = local_id + (thread_id // 16) * 16
+    j = thread_id % 16
+    return i, j
+
+
+def shared_16x64_to_local_64x16_layout_B(i, j):
+    thread_id = i + 16 * (j // 16)
+    local = (j % 16)
+    return thread_id, local
+
+
 def make_mfma_swizzle_layout(shared_buf, vecSize=8):
     dtype = shared_buf.dtype
     shape = shared_buf.shape
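The 16x64 index maps added above come in forward/reverse pairs (shared (row, col) -> (thread, lane) and back). As a quick sanity check, here is a standalone sketch that restates the two A-side functions exactly as they appear in the diff and verifies the round trip over the full 16x64 tile:

    # Standalone restatement of the A-side maps from the diff, plus a brute-force
    # round-trip check over the 16x64 tile (64 threads x 16 local elements).
    def shared_16x64_to_local_64x16_layout_A(i, j):
        thread_id = i + 16 * (j // 16)
        local = j % 16
        return thread_id, local

    def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
        i = thread_id % 16
        j = local_id + (thread_id // 16) * 16
        return i, j

    for i in range(16):
        for j in range(64):
            tid, local = shared_16x64_to_local_64x16_layout_A(i, j)
            assert 0 <= tid < 64 and 0 <= local < 16
            assert thread_id_shared_access_64x16_to_16x64_layout_A(tid, local) == (i, j)
    print("16x64 A-layout round-trip OK")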
@@ -27,6 +27,7 @@ class MatrixCoreIntrinEmitter(object):
         "int32": "int32",
         "e4m3_float8": "e4m3",
         "e5m2_float8": "e5m2",
+        "float8_e4m3fnuz": "e4m3fnuz",
    }

    # k_pack represents the number of elements in a vectorized instruction

@@ -80,10 +81,14 @@ class MatrixCoreIntrinEmitter(object):
    def _initialize_k_dim(self, a_dtype="float16"):
        if isinstance(a_dtype, str):
+           if a_dtype in ["float8_e4m3fnuz"]:
+               self.k_dim = 32
+               return
            a_dtype = DataType(a_dtype)

        if a_dtype.bits == 32:
            self.k_dim = 4
-       elif a_dtype.bits in [16, 8]:
+       elif a_dtype.bits in {16, 8}:
            self.k_dim = 16
        else:
            raise ValueError(f"Unsupported a_dtype = {a_dtype}")

@@ -112,10 +117,14 @@ class MatrixCoreIntrinEmitter(object):
            "float16": "f16",
            "float32": "f32",
            "int8": "i8",
-           "int32": "i32"
+           "int32": "i32",
+           "float8_e4m3fnuz": "fp8",
        }[in_dtype]

-       self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
+       if in_dtype_abbrv == "fp8":
+           self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
+       else:
+           self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"

    def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
        self.micro_size_x = m_dim

@@ -138,12 +147,16 @@ class MatrixCoreIntrinEmitter(object):
            shared_16x16_to_local_64x4_layout_B,
            shared_16x32_to_local_64x8_layout_A,
            shared_16x32_to_local_64x8_layout_B,
+           shared_16x64_to_local_64x16_layout_A,
+           shared_16x64_to_local_64x16_layout_B,
            thread_id_shared_access_64x1_to_16x4_layout_A,
            thread_id_shared_access_64x1_to_4x16_layout_B,
            thread_id_shared_access_64x4_to_16x16_layout_A,
            thread_id_shared_access_64x4_to_16x16_layout_B,
            thread_id_shared_access_64x8_to_16x32_layout_A,
            thread_id_shared_access_64x8_to_16x32_layout_B,
+           thread_id_shared_access_64x16_to_16x64_layout_A,
+           thread_id_shared_access_64x16_to_16x64_layout_B,
        )

        k_dim = self.k_dim * self.k_pack

@@ -168,8 +181,15 @@ class MatrixCoreIntrinEmitter(object):
            if is_b:
                index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
                reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
+       elif k_dim == 64:
+           index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
+           reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
+
+           if is_b:
+               index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
+               reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
        else:
-           raise ValueError("k_dim must be 4 or 16 currently")
+           raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")

        return index_map, reverse_index_map

@@ -228,7 +248,7 @@ class MatrixCoreIntrinEmitter(object):
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
-                       l, r = (rk * chunk + ki * micro_size_k,
+                       l, r = (rk * chunk + ki * (k_pack * micro_size_k),
                                warp_m * warp_row_tiles + i * micro_size_x)
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
                                                                                         r + col]

@@ -237,7 +257,7 @@ class MatrixCoreIntrinEmitter(object):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (warp_m * warp_row_tiles + i * micro_size_x,
-                               rk * chunk + ki * micro_size_k)
+                               rk * chunk + ki * (k_pack * micro_size_k))
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
                                                                                         r + col]

@@ -272,7 +292,7 @@ class MatrixCoreIntrinEmitter(object):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            warp_n * warp_col_tiles + j * micro_size_y,
-                           rk * chunk + ki * micro_size_k,
+                           rk * chunk + ki * (k_pack * micro_size_k),
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
                                                                                         r + col]

@@ -281,7 +301,7 @@ class MatrixCoreIntrinEmitter(object):
                    for local_id in T.vectorized(k_pack * local_size_b):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
-                           rk * chunk + ki * micro_size_k,
+                           rk * chunk + ki * (k_pack * micro_size_k),
                            warp_n * warp_col_tiles + j * micro_size_y,
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,

@@ -344,7 +364,7 @@ class MatrixCoreIntrinEmitter(object):
        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            for i, j in T.grid(warp_rows, warp_cols):
-               for local_id in T.serial(local_size_out):
+               for local_id in T.vectorized(local_size_out):
                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                    C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
                          col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +

@@ -354,7 +374,7 @@ class MatrixCoreIntrinEmitter(object):
        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            for i, j in T.grid(warp_rows, warp_cols):
-               for local_id in T.serial(local_size_out):
+               for local_id in T.vectorized(local_size_out):
                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                    C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
                          (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM +
@@ -475,6 +475,26 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
     A wrapper class for the TileLang HIP backend.
     """

+    _TYPE_MAP = {
+        "float32": "float",
+        "float16": "half_t",
+        "bfloat16": "bfloat16_t",
+        "e4m3_float8": "fp8_e4_t",
+        "e5m2_float8": "fp8_e5_t",
+        "float8_e4m3fnuz": "fp8_e4_t",
+        "e4m3fnuz_float8": "fp8_e4_t",
+        "float64": "double",
+        "int64": "int64_t",
+        "int32": "int",
+        "uint32": "unsigned int",
+        "bool": "int8_t",
+        "int8": "int8_t",
+        "uint8": "uint8_t",
+        "int16": "int16_t",
+        "uint16": "uint16_t",
+        "uchar": "uint8_t",
+    }
+
     def __init__(self,
                  scheduled_ir_module: IRModule,
                  source: str,
@@ -22,6 +22,7 @@ def map_torch_type(intype: str) -> torch.dtype:
     typemap = {
         'e4m3_float8': torch.float8_e4m3fn,
         'e5m2_float8': torch.float8_e5m2,
+        'e4m3fnuz_float8': torch.float8_e4m3fnuz,
     }
     if intype in typemap:
         return typemap[intype]
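Taken together, the dlpack, MFMA emitter, and HIP codegen changes chain a PyTorch FP8 tensor down to an AMD MFMA builtin. The sketch below walks that chain for a float8_e4m3fnuz input with a float32 accumulator. It is illustrative only, requires a PyTorch build that exposes torch.float8_e4m3fnuz, and the "float32" -> "f32" output abbreviation is an assumption consistent with the input-dtype table shown in the diff (the output-dtype map itself is not visible on this page).

    # Illustrative only: trace torch dtype -> TileLang dtype string -> MFMA suffix -> builtin name.
    import torch

    torch_to_tl = {torch.float8_e4m3fnuz: "float8_e4m3fnuz"}  # from dlpack.py
    in_dtype_abbrv = {"float8_e4m3fnuz": "fp8"}               # from mfma_macro_generator.py
    out_dtype_abbrv = {"float32": "f32"}                      # assumed, not shown in the diff

    in_dtype = torch_to_tl[torch.float8_e4m3fnuz]
    k_dim = 32 if in_dtype == "float8_e4m3fnuz" else 16       # mirrors _initialize_k_dim
    M_DIM = N_DIM = 16

    if in_dtype_abbrv[in_dtype] == "fp8":
        mfma_suffix = f"{out_dtype_abbrv['float32']}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
    else:
        mfma_suffix = f"{out_dtype_abbrv['float32']}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv[in_dtype]}"

    print(mfma_suffix)                             # f32_16x16x32_fp8_fp8
    print("__builtin_amdgcn_mfma_" + mfma_suffix)  # builtin name assembled in codegen_hip.cc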