Unverified commit f0c721a4, authored by Yunqian Fan, committed by GitHub

[Enhancement] add more dtype and fix mma.ws for fp16 for tcgen05 (#1327)

* feat: add fp8 variants; add placeholder for fp6/fp4 in meta

support ld with pack for fp32 dtype

add dump

add template expansion

remove unused dtype and switch to rebased APIs

* fix: enable_ws when atom_m != 128

* fix: typo in tcgen05 meta; dispatch in gemm sm100
parent f5d9da46
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            # K-loop: stage A/B tiles through shared memory and issue async tcgen05 MMAs
            # that accumulate into tensor memory.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm_v2(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=(k == 0),
                )
                # Wait for this iteration's MMA to complete before the buffers are reused.
                T.mbarrier_wait_parity(mbar, k % 2)

            # Epilogue: tensor memory -> registers -> shared memory -> global C.
            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def calc_diff(x, y):
    """Symmetric relative difference: 1 - 2*sum(x*y) / sum(x*x + y*y); 0 when x == y."""
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 64, 256, 32
trans_A, trans_B = False, True
num_stages = 2
threads = 256
for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
    for tvm_acc_dtype in ["float16", "float32"]:
        torch_fp8_dtype = map_torch_type(tvm_fp8_dtype)
        torch_acc_dtype = map_torch_type(tvm_acc_dtype)
        print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}")
        in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype
        func = matmul(
            M,
            N,
            K,
            block_M,
            block_N,
            block_K,
            trans_A,
            trans_B,
            in_dtype,
            out_dtype,
            accum_dtype,
            num_stages,
            threads,
        )
        jit_kernel = tilelang.compile(
            func,
            out_idx=[2],
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
                tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True,
            },
        )
        # jit_kernel.export_ptx("./dump.ptx")
        # jit_kernel.export_sources("./dump.cu")

        a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
        b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
        c = jit_kernel(a, b)
        ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float()
        c = c.float()
        diff = calc_diff(c, ref_c)
        # assert diff < 1e-3, f"{diff}"
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}")

        profiler = jit_kernel.get_profiler()
        latency = profiler.do_bench()
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
        print(
            f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS"
        )
......@@ -1118,6 +1118,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
  bool is_ld = false; // tcgen05.ld (tensor memory -> register)
  bool is_st = false; // tcgen05.st (register -> tensor memory)
  bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory)
  bool src_needs_pack =
      16 == src->dtype.bits(); // whether .pack::16b is needed for tcgen05.ld
  bool dst_needs_unpack =
      16 == dst->dtype.bits(); // whether .unpack::16b is needed for tcgen05.st
  if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") {
    is_ld = true;
  } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") {
......@@ -1125,9 +1130,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
} else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") {
is_cp = true;
} else {
ICHECK(0) << "Unsupported tensor memory copy: "
<< "src scope = " << src.scope()
<< ", dst scope = " << dst.scope();
ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = "
<< src.scope() << ", dst scope = " << dst.scope();
}
// Currently tcgen05.cp is not supported
// TODO (mzw) Support tcgen05.cp
......@@ -1247,8 +1251,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
              : relative_wg_idx * (num_chunks_each_wg * meta.width);
      have_succeeded = true;
      Array<PrimExpr> args;
      const char *bool_str = src_needs_pack ? "true" : "false";
      args.push_back(StringImm(meta.intrinsics_name + "<" +
                               std::to_string(num_chunks_each_wg) + ">"));
                               std::to_string(num_chunks_each_wg) + ", " +
                               bool_str + ">"));
      args.push_back(
          BufferLoad(src, {(int)logical_row_min,
                           (int)logical_col_min})); // Will be translated later
......
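To make the effect of the new flag concrete, here is a small, self-contained sketch (illustrative names and values, not output from an actual lowering run) of the intrinsic-name string that the snippet above now builds:

```cpp
#include <iostream>
#include <string>

int main() {
  // Illustrative values only: a 16-bit source dtype (the tensor-memory buffer)
  // sets src_needs_pack, and the chosen tcgen05.ld shape determines
  // num_chunks_each_wg.
  std::string intrinsics_name = "tl::tcgen05_ld_32dp32bNx";
  int num_chunks_each_wg = 4;
  bool src_needs_pack = true;

  const char *bool_str = src_needs_pack ? "true" : "false";
  std::string call = intrinsics_name + "<" +
                     std::to_string(num_chunks_each_wg) + ", " + bool_str + ">";
  std::cout << call << "\n";  // prints: tl::tcgen05_ld_32dp32bNx<4, true>
  return 0;
}
```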
......@@ -344,6 +344,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
        result.push_back(Integer(meta.atom_m));
        result.push_back(Integer(meta.atom_n));
        result.push_back(Integer(meta.atom_k));
        result.push_back(Integer(meta.enable_ws));
        result.push_back(Integer(meta.enable_2cta));
      }
      return result;
    });
......
......@@ -15,16 +15,19 @@ using runtime::DataType;
struct TCGEN5MMAMeta {
  int atom_m, atom_n, atom_k;
  bool enable_ws, enable_2cta;
};

inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
  // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
  return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
  return { \
    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
    false, TCGEN5MMAMeta { 0, 0, 0, false, false } \
  }
#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \
  return { \
    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \
  }
  std::vector<int> ws_valid_atom_ns = {256, 128, 64};
  if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
......@@ -34,39 +37,52 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
    if (M % 128 == 0) {
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 16);
          SUCCESS(128, atom_n, 16, false, false);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 16);
          SUCCESS(64, atom_n, 16, true, false);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 16);
          SUCCESS(32, atom_n, 16, true, false);
      FAIL;
    } else {
      FAIL;
    }
  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
             (c_dtype.is_float() && c_dtype.bits() == 32)) {
  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() ||
              ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() ||
              ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() ||
              ab_dtype.is_float4_e2m1fn()) &&
             ((c_dtype.is_float() && c_dtype.bits() == 32) ||
              (c_dtype.is_float16() && c_dtype.bits() == 16))) {
    if (K % 32 != 0)
      FAIL;
    if (M % 128 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 32, true, false);
      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 32);
          SUCCESS(128, atom_n, 32, false, true);
      for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
        if (N % atom_n == 0)
          SUCCESS(128, atom_n, 32, false, false);
      FAIL;
    } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 32);
          SUCCESS(64, atom_n, 32, true, false);
      for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
        if (N % atom_n == 0)
          SUCCESS(64, atom_n, 32, false, false);
      FAIL;
    } else if (M % 32 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
          SUCCESS(32, atom_n, 32);
          SUCCESS(32, atom_n, 32, true, false);
      FAIL;
    } else {
      FAIL;
......
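As a concrete illustration of the new selection order, the following standalone sketch (not the library's GetTCGEN5MMAMeta, just a mirror of the fp8 `M % 64 == 0` branch above) first tries the warp-specialized atom_n candidates {256, 128, 64} and then falls back to non-warp-specialized multiples of 8:

```cpp
#include <cstdio>

// Standalone illustration of the atom_n selection for the fp8, M % 64 == 0
// branch above; it mirrors the order of the SUCCESS(...) candidates but is
// not the library's GetTCGEN5MMAMeta itself.
struct Pick {
  bool ok;
  int atom_n;
  bool enable_ws;
};

Pick pick_atom_n_fp8_m64(int N) {
  // First preference: warp-specialized atoms (enable_ws = true).
  for (int atom_n : {256, 128, 64})
    if (N % atom_n == 0)
      return {true, atom_n, true};
  // Fallback: non-warp-specialized atoms, scanned in multiples of 8.
  for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
    if (N % atom_n == 0)
      return {true, atom_n, false};
  return {false, 0, false};
}

int main() {
  Pick p = pick_atom_n_fp8_m64(192);  // 192 is not divisible by 256/128, so 64 (WS) wins
  std::printf("atom_n=%d enable_ws=%d\n", p.atom_n, p.enable_ws);  // atom_n=64 enable_ws=1
  Pick q = pick_atom_n_fp8_m64(72);   // no WS candidate divides 72; falls back to 72
  std::printf("atom_n=%d enable_ws=%d\n", q.atom_n, q.enable_ws);  // atom_n=72 enable_ws=0
  return 0;
}
```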
......@@ -51,6 +51,21 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
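// 256-bit vectorized global load/store overloads for the fp8_e5_32_t (e5m2)
// 32-byte vector type, mirroring the fp8_e4_32_t variants above; both use a
// single ld.global.v4.u64 / st.global.v4.u64 instruction.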
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) {
  ulonglong4 ret;
  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
               : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
               : "l"(ptr));
  return ret;
}

__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr,
                                              fp8_e5_32_t &val8) {
  ulonglong4 &val = *((ulonglong4 *)&val8);
  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ unsigned long long
pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z,
......@@ -95,38 +110,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
  }
}

template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
                     uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp32bNx, 7, N>(tmem_start_col + tmem_col_offset,
                                               dst_ptr);
  tcgen05_ld_core<tl::tmem_ld_32dp32bNx<pack16>, 7, N>(
      tmem_start_col + tmem_col_offset, dst_ptr);
  tl::fence_view_async_tmem_load();
}

template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
                     uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp64bNx, 7, N>(tmem_start_col + tmem_col_offset,
                                               dst_ptr);
  tcgen05_ld_core<tl::tmem_ld_32dp64bNx<pack16>, 7, N>(
      tmem_start_col + tmem_col_offset, dst_ptr);
  tl::fence_view_async_tmem_load();
}

template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp128bNx, 6, N>(
  tcgen05_ld_core<tl::tmem_ld_32dp128bNx<pack16>, 6, N>(
      tmem_start_col + tmem_col_offset, dst_ptr);
  tl::fence_view_async_tmem_load();
}

template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp256bNx, 5, N>(
  tcgen05_ld_core<tl::tmem_ld_32dp256bNx<pack16>, 5, N>(
      tmem_start_col + tmem_col_offset, dst_ptr);
  tl::fence_view_async_tmem_load();
}
......
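For readers unfamiliar with the `.pack::16b` modifier that the new `pack16` parameter toggles: when the tensor-memory data is 16-bit, each 32-bit register produced by `tcgen05.ld` carries two adjacent 16-bit elements. A host-side sketch of just the packing arithmetic (conceptual only, not the PTX path):

```cpp
#include <cstdint>
#include <cstdio>

// Conceptual illustration of .pack::16b: two adjacent 16-bit values share one
// 32-bit register, low element in the low half, next element in the high half.
uint32_t pack_two_16b(uint16_t lo, uint16_t hi) {
  return static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
}

int main() {
  uint32_t packed = pack_two_16b(0x3C00 /* 1.0 in fp16 */, 0x4000 /* 2.0 in fp16 */);
  std::printf("0x%08X\n", packed);  // prints 0x40003C00
  return 0;
}
```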
......@@ -243,42 +243,94 @@ struct DispatchInstruction<half_t, half_t, float, M, N, K, a_major, b_major,
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K, a_major, b_major,
struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, float, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<M == 128 && K == 32>> {
  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, fp8_e4_t, fp8_e4_t, float, Int<M>,
                         Int<N>, integral_constant<UMMA::Major, a_major>,
  using MMA =
      MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e4m3_t, cute::float_e4m3_t,
                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
                 integral_constant<UMMA::Major, b_major>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K, a_major, b_major,
struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, float, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
  using MMA =
      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, fp8_e4_t, fp8_e4_t, float, Int<M>,
                 Int<N>, integral_constant<UMMA::Major, a_major>,
      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e4m3_t, cute::float_e4m3_t,
                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
                 integral_constant<UMMA::Major, b_major>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, half_t, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<M == 128 && K == 32>> {
  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e4m3_t,
                         cute::float_e4m3_t, half_t, Int<M>, Int<N>,
                         integral_constant<UMMA::Major, a_major>,
                         integral_constant<UMMA::Major, b_major>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<cute::float_e4m3_t, cute::float_e4m3_t, half_t, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
  using MMA = MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e4m3_t,
                         cute::float_e4m3_t, half_t, Int<M>, Int<N>,
                         integral_constant<UMMA::Major, a_major>,
                         integral_constant<UMMA::Major, b_major>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, M, N, K, a_major, b_major,
struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, float, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<M == 128 && K == 32>> {
  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, fp8_e5_t, fp8_e5_t, float, Int<M>,
                         Int<N>, integral_constant<UMMA::Major, a_major>,
  using MMA =
      MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e5m2_t, cute::float_e5m2_t,
                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
                 integral_constant<UMMA::Major, b_major>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, M, N, K, a_major, b_major,
struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, float, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
  using MMA =
      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, fp8_e5_t, fp8_e5_t, float, Int<M>,
                 Int<N>, integral_constant<UMMA::Major, a_major>,
      MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e5m2_t, cute::float_e5m2_t,
                 float, Int<M>, Int<N>, integral_constant<UMMA::Major, a_major>,
                 integral_constant<UMMA::Major, b_major>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                 integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, half_t, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<M == 128 && K == 32>> {
  using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, cute::float_e5m2_t,
                         cute::float_e5m2_t, half_t, Int<M>, Int<N>,
                         integral_constant<UMMA::Major, a_major>,
                         integral_constant<UMMA::Major, b_major>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};

template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<cute::float_e5m2_t, cute::float_e5m2_t, half_t, M, N,
                           K, a_major, b_major,
                           std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
  using MMA = MMA_Traits<SM100_MMA_F8F6F4_WS_SS, cute::float_e5m2_t,
                         cute::float_e5m2_t, half_t, Int<M>, Int<N>,
                         integral_constant<UMMA::Major, a_major>,
                         integral_constant<UMMA::Major, b_major>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
                         integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
......@@ -47,7 +47,10 @@ class TensorCoreIntrinEmitter:
"int8": "int8",
"int32": "int32",
"float8_e4m3": "e4m3",
"float8_e4m3fn": "e4m3",
"float8_e4m3fnuz": "e4m3",
"float8_e5m2": "e5m2",
"float8_e5m2fnuz": "e5m2",
}
# Represent the thread binding in the form of (tx, warp_n, warp_m)
......
......@@ -169,12 +169,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        accum_dtype_in_bits = DataType(accum_dtype).bits

        meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
        if len(meta) != 3:
        if len(meta) != 5:
            raise ValueError(
                f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
                f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
        atom_m, atom_n, atom_k = (int(x) for x in meta)
        enable_ws = atom_m != 128
        atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)

        # by default, we utilize non-swizzle layout offset
        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
......@@ -382,10 +381,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
        k = int(self.chunk)

        meta = self.get_tcgen5_mma_meta(m, n, k)
        if len(meta) != 3:
        if len(meta) != 5:
            raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
                             f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
        atom_m, atom_n, _ = (int(x) for x in meta)
        atom_m, atom_n, _, _, _ = (int(x) for x in meta)
        if m % atom_m != 0 or n % atom_n != 0:
            raise ValueError(
......
......@@ -144,6 +144,7 @@ class TLCUDASourceWrapper:
"float16": "half_t",
"bfloat16": "bfloat16_t",
"float8_e4m3": "fp8_e4_t",
"float8_e4m3fn": "fp8_e4_t",
"float8_e5m2": "fp8_e5_t",
"float64": "double",
"int64": "int64_t",
......
......@@ -85,6 +85,9 @@ class GemmTCGEN5(GemmBase):
raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got "
f"A scope {self.A.scope()}, B scope {self.B.scope()}")
atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(
self.M, self.N, self.K)
if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}:
raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}")
if self.B.scope() not in {"shared", "shared.dyn"}:
......@@ -105,7 +108,7 @@ class GemmTCGEN5(GemmBase):
raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access")
accum_dtype = str(self.C.dtype)
if accum_dtype != "float32":
if accum_dtype not in ["float32", 'float16']:
raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")
A_shared = self.ARegion
......