Commit 5872e647 authored by Lei Wang, committed by LeiWang1999

[AMD] Support float8 matrix core (#537)



* [Enhancement] Add support for FP8 types in CUDA and HIP code generation

* Updated the `GetFP8Type` function in `codegen_cuda.cc` and `codegen_hip.cc` to handle the new FP8 types, including `kFloat8_e4m3fnuz` (a sketch of the resulting type mapping follows this list).
* Introduced a new header file `hip_fp8.h` for FP8 type definitions in HIP.
* Modified type mappings in `dlpack.py` and `mfma_macro_generator.py` to accommodate new FP8 types.
* Enhanced type handling in `TLHIPSourceWrapper` and `tensor.py` for better integration with FP8 types.
* Added necessary includes and logic to support FP8 in the code generation process, improving performance and compatibility with FP8 data types.
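
A minimal Python paraphrase of the updated `GetFP8Type` mapping (the helper name below is illustrative and not part of the patch): scalar and 2/4/8/16-lane FP8 values are lowered to the `fp8_e4*_t` / `fp8_e5*_t` typedefs provided by `hip_fp8.h`, with `e4m3fn` and `e4m3fnuz` sharing the `e4` spelling.

```python
def get_fp8_hip_type(code: str, lanes: int = 1) -> str:
    """Sketch of GetFP8Type: pick the HIP FP8 typedef for a dtype code and lane count."""
    vec = {1: "", 2: "_2", 4: "_4", 8: "_8", 16: "_16"}.get(lanes)
    if vec is None:
        raise ValueError("Only scalar and 2/4/8/16-lane FP8 vectors are supported")
    if code in ("float8_e4m3fn", "float8_e4m3fnuz"):
        return f"fp8_e4{vec}_t"
    if code == "float8_e5m2":
        return f"fp8_e5{vec}_t"
    raise ValueError(f"Unsupported FP8 type: {code}")

# An 8-lane e4m3fnuz vector lowers to the fp8_e4_8_t struct defined in hip_fp8.h.
assert get_fp8_hip_type("float8_e4m3fnuz", 8) == "fp8_e4_8_t"
```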

* lint fix

* Update src/target/codegen_hip.cc
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update tilelang/intrinsics/mfma_macro_generator.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* workaround

* fix

* Update submodule TVM to latest commit 587028ffebfff0ded520f8f90d62f0f6b165906c

* bug fix

* Refactor tilelang matrix multiplication to support transposition and packing options. Adjusted shared memory shapes and loading logic for A and B matrices. Updated test cases to validate new functionality.

* Refactor assertion function for tilelang matrix multiplication to improve readability by formatting parameters and aligning code. Cleaned up whitespace in intrinsic layout functions for consistency.

* Update bfloat16 type definitions in common.h and gemm.h for consistency. Changed __hip_bfloat16 to hip_bfloat16 and updated MfmaTraits specialization accordingly.

* lint fix

---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 1940b3c9
......@@ -58,7 +58,7 @@ def flashmla_decode(batch,
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
for k in T.Pipelined(loop_range, num_stages=0):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
......
......@@ -41,6 +41,8 @@ static std::string GetFP8Type(DataType type) {
}
if (type.code() == DataType::kFloat8_e4m3fn) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3fnuz) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e5m2) {
stream << "fp8_e5" << vec << "_t";
} else {
......
......@@ -20,6 +20,36 @@
namespace tvm {
namespace codegen {
static std::string GetFP8Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
vec = "_2";
} else if (lanes == 4) {
vec = "_4";
} else if (lanes == 8) {
vec = "_8";
} else if (lanes == 16) {
vec = "_16";
} else {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
"for FP8";
}
if (type.code() == DataType::kFloat8_e4m3fn) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3fnuz) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e5m2) {
stream << "fp8_e5" << vec << "_t";
} else {
LOG(FATAL) << "Unsupported FP8 type in HIP codegen";
}
return stream.str();
}
/*!
* \brief Replace patterns with replacement strings.
* \note should use std::format instead when codebase is ported to C++20.
......@@ -104,6 +134,11 @@ std::string CodeGenTileLangHIP::Finish() {
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
if (enable_fp8_) {
decl_stream << "#include <tl_templates/hip/hip_fp8.h>\n";
}
decl_stream << "#include <tl_templates/hip/gemm.h>\n";
decl_stream << "#include <tl_templates/hip/copy.h>\n";
decl_stream << "#include <tl_templates/hip/reduce.h>\n";
......@@ -226,17 +261,8 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
if (!fail)
return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
// unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail)
enable_fp8_ = true;
os << GetFP8Type(t);
return;
} else if (t == DataType::Bool()) {
os << "bool";
......@@ -898,6 +924,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"},
{"float32x4", "float32x4"},
{"float8_e4m3fnuzx4", "fp8_e4_4_t"},
{"float8_e4m3fnuzx8", "long"},
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
*((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
......@@ -906,6 +934,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
})";
std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
Replacer replacer;
replacer.register_rule("{mfma_buildin}", mfma_buildin);
replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
......
......@@ -75,6 +75,8 @@ private:
bool need_math_constants_h_{false};
// whether need mfma.h
bool need_wmma_h_{false};
// whether need fp8.h
bool enable_fp8_{false};
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// whether need mma.h
......
......@@ -64,7 +64,7 @@ using float16x16 =
using half_t = float16_t;
using bfloat16_t = __hip_bfloat16;
using bfloat16_t = hip_bfloat16;
struct bfloat16x2 {
bfloat16_t data[2];
......
......@@ -17,16 +17,16 @@ template <> struct MfmaTraits<half> {
}
};
// Specialization for __hip_bfloat16
template <> struct MfmaTraits<__hip_bfloat16> {
// Specialization for bfloat16_t
template <> struct MfmaTraits<bfloat16_t> {
template <typename AccType>
static TL_DEVICE void mfma_op(const __hip_bfloat16 *b,
const __hip_bfloat16 *a, AccType *c) {
static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a,
AccType *c) {
bfloat16x4_vec b_vec, a_vec;
// Reinterpret the pointers
short *b_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(b));
short *a_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(a));
short *b_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(b));
short *a_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(a));
// Copy the data
for (int i = 0; i < 4; ++i) {
......
#include <hip/amd_detail/amd_hip_fp8.h>
using fp8_e4_t = __hip_fp8_e4m3_fnuz;
using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;
using fp8_e4_4_t = __hip_fp8x4_e4m3_fnuz;
struct __align__(8) fp8_e4_8_t {
fp8_e4_4_t x;
fp8_e4_4_t y;
};
struct __align__(16) fp8_e4_16_t {
fp8_e4_8_t x;
fp8_e4_8_t y;
};
......@@ -19,6 +19,9 @@ def tl_matmul(
in_dtype,
out_dtype,
accum_dtype,
a_transposed=False,
b_transposed=True,
k_pack=1,
):
assert in_dtype in [
"float16",
......@@ -32,27 +35,26 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if in_dtype in {"float8_e4m3fnuz", "int8"}:
micro_size_k = 32
block_row_warps = 1
block_col_warps = 1
warp_row_tiles = 16
warp_col_tiles = 16
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32
shared_scope = "shared.dyn"
shared_scope = "shared"
cache_write_shared = False
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
A_shape = (K, M) if a_transposed else (M, K)
B_shape = (N, K) if b_transposed else (K, N)
A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
......@@ -60,8 +62,8 @@ def tl_matmul(
warp_size = 64
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
......@@ -71,13 +73,14 @@ def tl_matmul(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
a_transposed=a_transposed,
b_transposed=b_transposed,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
k_pack=k_pack,
)
@T.prim_func
......@@ -108,14 +111,18 @@ def tl_matmul(
for ko in T.Pipelined((K // block_K), num_stages=0):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
if a_transposed:
T.copy(A[ko * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, ko * block_K], A_shared)
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
if b_transposed:
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[ko * block_K, bx * block_N], B_shared)
for ki in T.serial(0, (block_K // micro_size_k)):
for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
# Load A into fragment
mfma_emitter.ldmatrix_a(
......@@ -160,20 +167,30 @@ def tl_matmul(
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
def assert_tl_matmul_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype="float32",
a_transposed=False,
b_transposed=True,
k_pack=1):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
k_pack)
print(matmul)
kernel = tilelang.compile(matmul)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
A_shape = (K, M) if a_transposed else (M, K)
B_shape = (N, K) if b_transposed else (K, N)
if in_dtype == "int8":
A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
kernel(A, B, C)
......@@ -185,8 +202,22 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
# Ensure that the latency is not None
assert latency is not None
if a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.T.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
elif a_transposed and not b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
ref_c = torch.matmul(A.T.to(torch.float32),
B.to(torch.float32)).to(getattr(torch, out_dtype))
elif not a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
else:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
......@@ -196,6 +227,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
if __name__ == "__main__":
......
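
As a usage note (not part of the patch), the new layout options could be exercised directly through the test helper; the call below is a hypothetical extra case and assumes the transposed paths are supported end to end for float16 (float8 inputs would need a different random-initialization path than `torch.rand`):

```python
# Hypothetical extra case exercising the new transposition flags with float16 inputs.
assert_tl_matmul_correctness(
    128, 128, 128, "float16", "float16",
    a_transposed=True, b_transposed=False, k_pack=1)
```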
......@@ -38,7 +38,7 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
float8_dtype_map = {
torch.float8_e4m3fn: "e4m3_float8",
torch.float8_e4m3fnuz: "e4m3_float8",
torch.float8_e4m3fnuz: "float8_e4m3fnuz",
torch.float8_e5m2: "e5m2_float8",
torch.float8_e5m2fnuz: "e5m2_float8",
}
......
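
For illustration, a standalone sketch of how this table translates a torch FP8 dtype into the TileLang dtype string (the helper name is hypothetical; assumes a PyTorch build that exposes the FP8 dtypes):

```python
import torch

_FLOAT8_DTYPE_MAP = {
    torch.float8_e4m3fn: "e4m3_float8",
    torch.float8_e4m3fnuz: "float8_e4m3fnuz",  # updated fnuz spelling
    torch.float8_e5m2: "e5m2_float8",
    torch.float8_e5m2fnuz: "e5m2_float8",
}

def torch_fp8_to_tl_dtype(dtype: torch.dtype) -> str:
    """Hypothetical helper mirroring the mapping above."""
    try:
        return _FLOAT8_DTYPE_MAP[dtype]
    except KeyError as err:
        raise ValueError(f"Unsupported FP8 dtype: {dtype}") from err

# The AMD-only fnuz variant now maps to its own dtype string.
assert torch_fp8_to_tl_dtype(torch.float8_e4m3fnuz) == "float8_e4m3fnuz"
```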
......@@ -61,6 +61,12 @@ def shared_16x16_to_local_64x4_layout_B(i, j):
return thread_id, local
shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_m = shared_16x16_to_local_64x4_layout_B
shared_16x16_to_local_64x4_layout_k_n = shared_16x16_to_local_64x4_layout_B
def thread_id_shared_access_64x4_to_16x16_layout_C_m_n(thread_id, local_id):
i = local_id + (thread_id // 16) * 4
j = thread_id % 16
......@@ -97,6 +103,30 @@ def shared_16x32_to_local_64x8_layout_B(i, j):
return thread_id, local
def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
i = thread_id % 16
j = local_id + (thread_id // 16) * 16
return i, j
def shared_16x64_to_local_64x16_layout_A(i, j):
thread_id = i + 16 * (j // 16)
local = (j % 16)
return thread_id, local
def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id):
i = local_id + (thread_id // 16) * 16
j = thread_id % 16
return i, j
def shared_16x64_to_local_64x16_layout_B(i, j):
thread_id = i + 16 * (j // 16)
local = (j % 16)
return thread_id, local
def make_mfma_swizzle_layout(shared_buf, vecSize=8):
dtype = shared_buf.dtype
shape = shared_buf.shape
......
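
A quick standalone check (definitions copied from the new functions above) that the A-side 16x64 shared layout and its 64x16 register-access counterpart are mutual inverses, i.e. every (thread_id, local_id) pair maps to a unique shared-tile element and back:

```python
def thread_id_shared_access_64x16_to_16x64_layout_A(thread_id, local_id):
    i = thread_id % 16
    j = local_id + (thread_id // 16) * 16
    return i, j

def shared_16x64_to_local_64x16_layout_A(i, j):
    thread_id = i + 16 * (j // 16)
    local = j % 16
    return thread_id, local

seen = set()
for tid in range(64):        # 64 lanes per wavefront
    for lid in range(16):    # 16 FP8 elements per lane
        i, j = thread_id_shared_access_64x16_to_16x64_layout_A(tid, lid)
        assert 0 <= i < 16 and 0 <= j < 64
        assert shared_16x64_to_local_64x16_layout_A(i, j) == (tid, lid)
        seen.add((i, j))
assert len(seen) == 16 * 64  # bijection over the whole 16x64 tile
```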
......@@ -27,6 +27,7 @@ class MatrixCoreIntrinEmitter(object):
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
"float8_e4m3fnuz": "e4m3fnuz",
}
# k_pack represents the number of elements in a vectorized instruction
......@@ -80,10 +81,14 @@ class MatrixCoreIntrinEmitter(object):
def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
if a_dtype in ["float8_e4m3fnuz"]:
self.k_dim = 32
return
a_dtype = DataType(a_dtype)
if a_dtype.bits == 32:
self.k_dim = 4
elif a_dtype.bits in [16, 8]:
elif a_dtype.bits in {16, 8}:
self.k_dim = 16
else:
raise ValueError(f"Unsupported a_dtype = {a_dtype}")
......@@ -112,9 +117,13 @@ class MatrixCoreIntrinEmitter(object):
"float16": "f16",
"float32": "f32",
"int8": "i8",
"int32": "i32"
"int32": "i32",
"float8_e4m3fnuz": "fp8",
}[in_dtype]
if in_dtype_abbrv == "fp8":
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
else:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
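
For reference, a standalone paraphrase of the suffix selection above (the function name is illustrative): FP8 intrinsics name both operand types explicitly, so float8_e4m3fnuz inputs with float32 accumulation and k_dim = 32 should select `__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8` once the codegen prepends the builtin prefix.

```python
def mfma_suffix(out_dtype_abbrv: str, in_dtype_abbrv: str, m: int, n: int, k: int) -> str:
    # FP8 intrinsics spell out both input operand types ("..._fp8_fp8");
    # other dtypes append a single input-type abbreviation.
    if in_dtype_abbrv == "fp8":
        return f"{out_dtype_abbrv}_{m}x{n}x{k}_fp8_fp8"
    return f"{out_dtype_abbrv}_{m}x{n}x{k}{in_dtype_abbrv}"

assert mfma_suffix("f32", "fp8", 16, 16, 32) == "f32_16x16x32_fp8_fp8"
assert mfma_suffix("f32", "f16", 16, 16, 16) == "f32_16x16x16f16"
```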
......@@ -138,12 +147,16 @@ class MatrixCoreIntrinEmitter(object):
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
)
k_dim = self.k_dim * self.k_pack
......@@ -168,8 +181,15 @@ class MatrixCoreIntrinEmitter(object):
if is_b:
index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
elif k_dim == 64:
index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
if is_b:
index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
else:
raise ValueError("k_dim must be 4 or 16 currently")
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
return index_map, reverse_index_map
......@@ -228,7 +248,7 @@ class MatrixCoreIntrinEmitter(object):
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * micro_size_k,
l, r = (rk * chunk + ki * (k_pack * micro_size_k),
warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]
......@@ -237,7 +257,7 @@ class MatrixCoreIntrinEmitter(object):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k)
rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]
......@@ -272,7 +292,7 @@ class MatrixCoreIntrinEmitter(object):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * micro_size_k,
rk * chunk + ki * (k_pack * micro_size_k),
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
......@@ -281,7 +301,7 @@ class MatrixCoreIntrinEmitter(object):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * chunk + ki * micro_size_k,
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
......@@ -344,7 +364,7 @@ class MatrixCoreIntrinEmitter(object):
def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.serial(local_size_out):
for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row,
col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
......@@ -354,7 +374,7 @@ class MatrixCoreIntrinEmitter(object):
def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.serial(local_size_out):
for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
(pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM +
......
......@@ -475,6 +475,26 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
A wrapper class for the TileLang HIP backend.
"""
_TYPE_MAP = {
"float32": "float",
"float16": "half_t",
"bfloat16": "bfloat16_t",
"e4m3_float8": "fp8_e4_t",
"e5m2_float8": "fp8_e5_t",
"float8_e4m3fnuz": "fp8_e4_t",
"e4m3fnuz_float8": "fp8_e4_t",
"float64": "double",
"int64": "int64_t",
"int32": "int",
"uint32": "unsigned int",
"bool": "int8_t",
"int8": "int8_t",
"uint8": "uint8_t",
"int16": "int16_t",
"uint16": "uint16_t",
"uchar": "uint8_t",
}
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
......
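
A minimal sketch (map values copied from the diff; the helper name is assumed) of how the wrapper resolves TileLang dtype strings when emitting host-side argument types; note that both spellings of the fnuz FP8 dtype resolve to the same HIP typedef:

```python
_TYPE_MAP = {
    "float32": "float",
    "float16": "half_t",
    "bfloat16": "bfloat16_t",
    "e4m3_float8": "fp8_e4_t",
    "e5m2_float8": "fp8_e5_t",
    "float8_e4m3fnuz": "fp8_e4_t",
    "e4m3fnuz_float8": "fp8_e4_t",
    "int8": "int8_t",
}

def hip_ctype(dtype: str) -> str:
    return _TYPE_MAP[dtype]

assert hip_ctype("float8_e4m3fnuz") == hip_ctype("e4m3fnuz_float8") == "fp8_e4_t"
```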
......@@ -22,6 +22,7 @@ def map_torch_type(intype: str) -> torch.dtype:
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
'e4m3fnuz_float8': torch.float8_e4m3fnuz,
}
if intype in typemap:
return typemap[intype]
......