Unverified commit f0c721a4, authored by Yunqian Fan and committed by GitHub

[Enhancement] add more dtypes and fix mma.ws for fp16 for tcgen05 (#1327)

* feat: add fp8 variants; add placeholder for fp6/fp4 in meta

support ld with pack for fp32 dtype

add dump

add template expansion

remove unused dtypes and switch to rebased APIs

* fix: enable_ws when atom_m != 128

* fix: typo in tcgen05 meta; dispatch in gemm sm100
parent f5d9da46

import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm_v2(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=(k == 0),
                )
                T.mbarrier_wait_parity(mbar, k % 2)

            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main

def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
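
For reference, calc_diff(x, y) = 1 - 2*sum(x*y) / sum(x*x + y*y): it is exactly 0 when the tensors match and grows toward 1 as they decorrelate. A quick illustration (illustrative only, not part of the commit):

x = torch.ones(4)
print(calc_diff(x, x))         # tensor(0., dtype=torch.float64)
print(calc_diff(x, x * 1.01))  # ~4.95e-05 for a 1% uniform perturbation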

M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 64, 256, 32
trans_A, trans_B = False, True
num_stages = 2
threads = 256

for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
    for tvm_acc_dtype in ["float16", "float32"]:
        torch_fp8_dtype = map_torch_type(tvm_fp8_dtype)
        torch_acc_dtype = map_torch_type(tvm_acc_dtype)
        print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}")
        in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype
        func = matmul(
            M,
            N,
            K,
            block_M,
            block_N,
            block_K,
            trans_A,
            trans_B,
            in_dtype,
            out_dtype,
            accum_dtype,
            num_stages,
            threads,
        )
        jit_kernel = tilelang.compile(
            func,
            out_idx=[2],
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
                tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True,
            },
        )
        # jit_kernel.export_ptx("./dump.ptx")
        # jit_kernel.export_sources("./dump.cu")
        a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
        b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
        c = jit_kernel(a, b)
        ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float()
        c = c.float()
        diff = calc_diff(c, ref_c)
        # assert diff < 1e-3, f"{diff}"
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}")

        profiler = jit_kernel.get_profiler()
        latency = profiler.do_bench()
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
        print(
            f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS"
        )
@@ -1118,6 +1118,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
   bool is_ld = false; // tcgen05.ld (tensor memory -> register)
   bool is_st = false; // tcgen05.st (register -> tensor memory)
   bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory)
+  bool src_needs_pack =
+      16 == src->dtype.bits(); // if needs .pack::16b when is_ld
+  bool dst_needs_unpack =
+      16 == dst->dtype.bits(); // if needs .unpack::16b when is_st
   if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") {
     is_ld = true;
   } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") {
@@ -1125,9 +1130,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
   } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") {
     is_cp = true;
   } else {
-    ICHECK(0) << "Unsupported tensor memory copy: "
-              << "src scope = " << src.scope()
-              << ", dst scope = " << dst.scope();
+    ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = "
+              << src.scope() << ", dst scope = " << dst.scope();
   }
   // Currently tcgen05.cp is not supported
   // TODO (mzw) Support tcgen05.cp
@@ -1247,8 +1251,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
               : relative_wg_idx * (num_chunks_each_wg * meta.width);
       have_succeeded = true;
       Array<PrimExpr> args;
+      const char *bool_str = src_needs_pack ? "true" : "false";
       args.push_back(StringImm(meta.intrinsics_name + "<" +
-                               std::to_string(num_chunks_each_wg) + ">"));
+                               std::to_string(num_chunks_each_wg) + ", " +
+                               bool_str + ">"));
       args.push_back(
           BufferLoad(src, {(int)logical_row_min,
                            (int)logical_col_min})); // Will be translated later
...
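With the pack flag threaded through as a second template argument, a 16-bit destination dtype now selects the .pack::16b form of tcgen05.ld, which packs two 16-bit elements into each 32-bit register. A minimal Python sketch of the name construction above (illustrative only; the tl:: prefix and the chunk count are assumptions):

def intrinsic_name(base, num_chunks, src_needs_pack):
    # mirrors: intrinsics_name + "<" + num_chunks + ", " + bool_str + ">"
    return f"{base}<{num_chunks}, {'true' if src_needs_pack else 'false'}>"

print(intrinsic_name("tl::tcgen05_ld_32dp32bNx", 4, True))
# tl::tcgen05_ld_32dp32bNx<4, true>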
@@ -344,6 +344,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       result.push_back(Integer(meta.atom_m));
       result.push_back(Integer(meta.atom_n));
       result.push_back(Integer(meta.atom_k));
+      result.push_back(Integer(meta.enable_ws));
+      result.push_back(Integer(meta.enable_2cta));
     }
     return result;
   });
...
@@ -15,16 +15,19 @@ using runtime::DataType;
 struct TCGEN5MMAMeta {
   int atom_m, atom_n, atom_k;
+  bool enable_ws, enable_2cta;
 };

 inline std::pair<bool, TCGEN5MMAMeta>
 GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
 // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
 #define FAIL \
-  return { false, TCGEN5MMAMeta{0, 0, 0} }
-#define SUCCESS(atom_m, atom_n, atom_k) \
   return { \
-    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
+    false, TCGEN5MMAMeta { 0, 0, 0, false, false } \
+  }
+#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \
+  return { \
+    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \
   }

 std::vector<int> ws_valid_atom_ns = {256, 128, 64};
 if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
@@ -34,39 +37,52 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
   if (M % 128 == 0) {
     for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
       if (N % atom_n == 0)
-        SUCCESS(128, atom_n, 16);
+        SUCCESS(128, atom_n, 16, false, false);
     FAIL;
   } else if (M % 64 == 0) {
     for (int atom_n : ws_valid_atom_ns)
       if (N % atom_n == 0)
-        SUCCESS(64, atom_n, 16);
+        SUCCESS(64, atom_n, 16, true, false);
     FAIL;
   } else if (M % 32 == 0) {
     for (int atom_n : ws_valid_atom_ns)
       if (N % atom_n == 0)
-        SUCCESS(32, atom_n, 16);
+        SUCCESS(32, atom_n, 16, true, false);
     FAIL;
   } else {
     FAIL;
   }
-} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
-           (c_dtype.is_float() && c_dtype.bits() == 32)) {
+} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() ||
+            ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() ||
+            ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() ||
+            ab_dtype.is_float4_e2m1fn()) &&
+           ((c_dtype.is_float() && c_dtype.bits() == 32) ||
+            (c_dtype.is_float16() && c_dtype.bits() == 16))) {
   if (K % 32 != 0)
     FAIL;
   if (M % 128 == 0) {
+    for (int atom_n : ws_valid_atom_ns)
+      if (N % atom_n == 0)
+        SUCCESS(128, atom_n, 32, true, false);
     for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
      if (N % atom_n == 0)
-        SUCCESS(128, atom_n, 32);
+        SUCCESS(128, atom_n, 32, false, true);
+    for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
+      if (N % atom_n == 0)
+        SUCCESS(128, atom_n, 32, false, false);
     FAIL;
   } else if (M % 64 == 0) {
     for (int atom_n : ws_valid_atom_ns)
       if (N % atom_n == 0)
-        SUCCESS(64, atom_n, 32);
+        SUCCESS(64, atom_n, 32, true, false);
+    for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
+      if (N % atom_n == 0)
+        SUCCESS(64, atom_n, 32, false, false);
     FAIL;
   } else if (M % 32 == 0) {
     for (int atom_n : ws_valid_atom_ns)
       if (N % atom_n == 0)
-        SUCCESS(32, atom_n, 32);
+        SUCCESS(32, atom_n, 32, true, false);
     FAIL;
   } else {
     FAIL;
...
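Read as a priority list, the fp8/fp6/fp4 branch now tries warp-specialized atoms first, then the 2-CTA path, then plain single-CTA atoms. A Python restatement of the M % 128 == 0 branch (a sketch of the C++ above, not library code):

WS_VALID_ATOM_NS = [256, 128, 64]

def fp8_meta_m128(N, K):
    if K % 32 != 0:
        return None                                # FAIL
    for atom_n in WS_VALID_ATOM_NS:                # warp-specialized first
        if N % atom_n == 0:
            return (128, atom_n, 32, True, False)  # enable_ws
    for atom_n in range(256, 15, -16):             # then the 2-CTA path
        if N % atom_n == 0:
            return (128, atom_n, 32, False, True)  # enable_2cta
    for atom_n in range(256, 7, -8):               # then plain single-CTA
        if N % atom_n == 0:
            return (128, atom_n, 32, False, False)
    return None                                    # FAIL

print(fp8_meta_m128(4096, 8192))  # (128, 256, 32, True, False)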
@@ -51,6 +51,21 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
 }

+__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) {
+  ulonglong4 ret;
+  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
+               : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
+               : "l"(ptr));
+  return ret;
+}
+
+__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr,
+                                              fp8_e5_32_t &val8) {
+  ulonglong4 &val = *((ulonglong4 *)&val8);
+  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
+               :
+               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
+}
+
 __device__ __forceinline__ unsigned long long
 pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z,
@@ -95,38 +110,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
   }
 }

-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp32bNx, 7, N>(tmem_start_col + tmem_col_offset,
-                                               dst_ptr);
+  tcgen05_ld_core<tl::tmem_ld_32dp32bNx<pack16>, 7, N>(
+      tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }

-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp64bNx, 7, N>(tmem_start_col + tmem_col_offset,
-                                               dst_ptr);
+  tcgen05_ld_core<tl::tmem_ld_32dp64bNx<pack16>, 7, N>(
+      tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }

-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
                       uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp128bNx, 6, N>(
-      tmem_start_col + tmem_col_offset, dst_ptr);
+  tcgen05_ld_core<tl::tmem_ld_32dp128bNx<pack16>, 6, N>(
+      tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }

-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
                       uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp256bNx, 5, N>(
-      tmem_start_col + tmem_col_offset, dst_ptr);
+  tcgen05_ld_core<tl::tmem_ld_32dp256bNx<pack16>, 5, N>(
+      tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }
...
@@ -243,47 +243,99 @@ struct DispatchInstruction<half_t, half_t, float, M, N, K, a_major, b_major,
 };

 template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
-struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K, a_major, b_major,
+struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, float, M, N,
+                           K, a_major, b_major,
                            std::enable_if_t<M == 128 && K == 32>> {
-  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, fp8_e4_t, fp8_e4_t, float, Int<M>,
-                         Int<N>, integral_constant<UMMA::Major, a_major>,
-                         integral_constant<UMMA::Major, b_major>,
-                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
-                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
+  using MMA =
+      MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e4m3_t, cute::float_e4m3_t,
+                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
+                 integral_constant<UMMA::Major, b_major>,
+                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
+                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
 };

 template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
-struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K, a_major, b_major,
+struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, float, M, N,
+                           K, a_major, b_major,
                            std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
   using MMA =
-      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, fp8_e4_t, fp8_e4_t, float, Int<M>,
-                 Int<N>, integral_constant<UMMA::Major, a_major>,
+      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e4m3_t, cute::float_e4m3_t,
+                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
                  integral_constant<UMMA::Major, b_major>,
                  integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                  integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
 };

 template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
-struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, M, N, K, a_major, b_major,
+struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, half_t, M, N,
+                           K, a_major, b_major,
                            std::enable_if_t<M == 128 && K == 32>> {
-  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, fp8_e5_t, fp8_e5_t, float, Int<M>,
-                         Int<N>, integral_constant<UMMA::Major, a_major>,
+  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e4m3_t,
+                         cute::float_e4m3_t, half_t, Int<M>, Int<N>,
+                         integral_constant<UMMA::Major, a_major>,
                          integral_constant<UMMA::Major, b_major>,
                          integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                          integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
 };
+template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
+struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, half_t, M, N,
+                           K, a_major, b_major,
+                           std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
+  using MMA = MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e4m3_t,
+                         cute::float_e4m3_t, half_t, Int<M>, Int<N>,
+                         integral_constant<UMMA::Major, a_major>,
+                         integral_constant<UMMA::Major, b_major>,
+                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
+                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
+};
+template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
+struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, float, M, N,
+                           K, a_major, b_major,
+                           std::enable_if_t<M == 128 && K == 32>> {
+  using MMA =
+      MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e5m2_t, cute::float_e5m2_t,
+                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
+                 integral_constant<UMMA::Major, b_major>,
+                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
+                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
+};
 template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
-struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, M, N, K, a_major, b_major,
+struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, float, M, N,
+                           K, a_major, b_major,
                            std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
   using MMA =
-      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, fp8_e5_t, fp8_e5_t, float, Int<M>,
-                 Int<N>, integral_constant<UMMA::Major, a_major>,
+      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e5m2_t, cute::float_e5m2_t,
+                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
                  integral_constant<UMMA::Major, b_major>,
                  integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                  integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
 };
+template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
+struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, half_t, M, N,
+                           K, a_major, b_major,
+                           std::enable_if_t<M == 128 && K == 32>> {
+  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e5m2_t,
+                         cute::float_e5m2_t, half_t, Int<M>, Int<N>,
+                         integral_constant<UMMA::Major, a_major>,
+                         integral_constant<UMMA::Major, b_major>,
+                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
+                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
+};
+template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
+struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, half_t, M, N,
+                           K, a_major, b_major,
+                           std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
+  using MMA = MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e5m2_t,
+                         cute::float_e5m2_t, half_t, Int<M>, Int<N>,
+                         integral_constant<UMMA::Major, a_major>,
+                         integral_constant<UMMA::Major, b_major>,
+                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
+                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
+};
 template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
           bool trans_B, typename A_type_raw, typename B_type_raw,
           typename C_type_raw>
...
@@ -47,7 +47,10 @@ class TensorCoreIntrinEmitter:
         "int8": "int8",
         "int32": "int32",
         "float8_e4m3": "e4m3",
+        "float8_e4m3fn": "e4m3",
+        "float8_e4m3fnuz": "e4m3",
         "float8_e5m2": "e5m2",
+        "float8_e5m2fnuz": "e5m2",
     }

     # Represent the thread binding in the form of (tx, warp_n, warp_m)
...
@@ -169,12 +169,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
         accum_dtype_in_bits = DataType(accum_dtype).bits

         meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
-        if len(meta) != 3:
+        if len(meta) != 5:
             raise ValueError(
                 f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
                 f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
-        atom_m, atom_n, atom_k = (int(x) for x in meta)
-        enable_ws = atom_m != 128
+        atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)

         # by default, we utilize non-swizzle layout offset
         a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
@@ -382,10 +381,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
         k = int(self.chunk)

         meta = self.get_tcgen5_mma_meta(m, n, k)
-        if len(meta) != 3:
+        if len(meta) != 5:
             raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
                              f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
-        atom_m, atom_n, _ = (int(x) for x in meta)
+        atom_m, atom_n, _, _, _ = (int(x) for x in meta)

         if m % atom_m != 0 or n % atom_n != 0:
             raise ValueError(
...
@@ -144,6 +144,7 @@ class TLCUDASourceWrapper:
         "float16": "half_t",
         "bfloat16": "bfloat16_t",
         "float8_e4m3": "fp8_e4_t",
+        "float8_e4m3fn": "fp8_e4_t",
         "float8_e5m2": "fp8_e5_t",
         "float64": "double",
         "int64": "int64_t",
...
@@ -85,6 +85,9 @@ class GemmTCGEN5(GemmBase):
             raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got "
                              f"A scope {self.A.scope()}, B scope {self.B.scope()}")

+        atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(
+            self.M, self.N, self.K)
+
         if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}:
             raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}")
         if self.B.scope() not in {"shared", "shared.dyn"}:
@@ -105,7 +108,7 @@ class GemmTCGEN5(GemmBase):
             raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access")

         accum_dtype = str(self.C.dtype)
-        if accum_dtype != "float32":
+        if accum_dtype not in ["float32", 'float16']:
             raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")

         A_shared = self.ARegion
...