Commit 2add9fa3 authored by wangkx1

add tilelang

parent f5bc26c2
import tilelang
import tilelang.language as T


# Add the decorator @tilelang.jit if you want matmul to return a torch-callable kernel directly.
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize the kernel context: a 2D grid of thread blocks with 256 threads each
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Clear the local accumulator
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy a tile of A into shared memory.
                # This is syntactic sugar for a parallelized copy:
                #   for i, k in T.Parallel(block_M, block_K):
                #       A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
                T.copy(A[by * block_M, ko * block_K], A_shared)
                # Copy a tile of B
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                # Perform a tile-level GEMM on the shared buffers.
                # Currently this dispatches to CUTE on NVIDIA GPUs and HIP on AMD GPUs.
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Copy the result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

M = 128 # M = T.dynamic("m") if you want to use dynamic shape
N = 128
K = 32
block_M = 128
block_N = 128
block_K = 32
# 1. Define the kernel (matmul) as a TileLang program
func = matmul(M, N, K, block_M, block_N, block_K)
# 2. Compile the kernel into a torch-callable function.
# out_idx specifies the indices of the output buffers in the argument list;
# if out_idx is given, those tensors are allocated at runtime and returned.
# target can currently be "cuda", "hip", or "cpu".
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
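# With out_idx=[2], the output tensor C (argument index 2) is allocated by the
# runtime and returned, so the kernel is invoked below as c = jit_kernel(a, b)
# rather than jit_kernel(a, b, c).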
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)
# Run the compiled kernel
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b.T
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile kernel latency with the built-in profiler
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import torch
import tilelang
import tilelang.language as T


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)

            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main
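
# Note on the synchronization above (an interpretation, not authoritative):
# T.gemm(..., wg_wait=-1) issues the MMA asynchronously into tensor memory
# (C_tmem) and arrives on `mbar` once iteration k's accumulation has landed;
# T.mbarrier_wait_parity(mbar, k % 2) then waits on the alternating phase bit
# before the next pipeline stage reuses the shared-memory tiles. The copies in
# the loop also index A and B as if trans_A=False and trans_B=True, matching
# the configuration used below.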
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
num_stages = 2
threads = 256
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
print(jit_kernel.get_kernel_source())
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import randn_semi_sparse
from tilelang.utils.tensor import torch_assert_close
from triton.testing import do_bench
import torch
torch.manual_seed(42)
DEFAULT_CONFIG = {  # take best config from autotune script
    "4090": {
        T.float: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 64,
            "num_stages": 1,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        T.float16: {
            "block_M": 256,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
    "h20": {
        T.float: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        T.float16: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}  # compute capability -> (K elements per metadata element, metadata dtype)
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(
    M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout
):
    e_factor, e_dtype = (16, T.int16)

    @T.prim_func
    def gemm_sp_fp16_custom_compress(
            A_sparse: T.Tensor((M, K // 2), T.float16),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), T.float16),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), T.float16)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            if use_cutlass_layout:
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                    }
                )
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16_custom_compress
def torch_compress(dense):
    """
    A naive compression function: each 4-bit metadata group encodes the positions
    of the two kept elements within 4 consecutive elements of the original matrix,
    in row-major layout.
    """
    if dense.dim() != 2:
        raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")

    m, k = dense.shape

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16")
    else:
        if m % 32 != 0:
            raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}")

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)
    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, the lower two bits of the encoding are the index of the True value
    # at the lowest index in the quadruple, and the higher two bits are the
    # index of the other True value. If there are fewer than two True values,
    # then False value(s) at some index or indices are treated as True for the
    # encoding. If there are more than two True values, the excess True
    # value(s) at some indices are treated as False. The exact encodings used
    # for these cases are as follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings were chosen, with the help of the Espresso
    # logic minimizer, to minimize the Boolean functions that translate
    # non-zero flags into encoding bits. Note also that the possible choices
    # for the first and last of these encodings were limited to (0b0100,
    # 0b1110) in order to produce valid encodings for the 1:2 sparsity case.
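    # Worked example (illustration): the quadruple [False, True, False, True]
    # keeps the elements at indices 1 and 3, so its nibble is 0b1101; the low
    # two bits 0b01 give the first kept index (1) and the high two bits 0b11
    # give the second kept index (3).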
expr0 = m0 & m1
expr1 = ~m0 & m1
expr2 = ~m0 & ~m1
bit0 = expr1
bit1 = expr2
bit2 = expr0 | expr2 | m3
bit3 = expr1 | ~m1
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
if quadbits_per_meta_elem == 4:
meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12)
elif quadbits_per_meta_elem == 8:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
| (meta_n[:, :, 4] << 16)
| (meta_n[:, :, 5] << 20)
| (meta_n[:, :, 6] << 24)
| (meta_n[:, :, 7] << 28)
)
return (sparse, meta)

def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
    assert meta.dtype is torch.int16
    groups_per_meta = 16 // 4  # 4 nibble groups per int16 element
    out = []
    for g in range(groups_per_meta):
        group_bits = (meta >> (g * 4)) & 0xF
        idx0 = group_bits & 0x3
        idx1 = (group_bits >> 2) & 0x3
        out.append(torch.stack([idx0, idx1], dim=-1))
    return torch.concat(out, dim=-1).view(meta.shape[0], -1)
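
# Round-trip sketch (illustrative only; names and shapes follow the two
# functions above):
#
#   dense = randn_semi_sparse(32, 64, device="cuda", dtype=torch.half)
#   sparse, meta = torch_compress(dense)  # sparse: (32, 32), meta: (32, 4) int16
#   idxs = decode_metadata(meta)          # (32, 32): two kept indices per 4-element group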

@tilelang.jit(
    out_idx=[1, 2],
    pass_configs={
        tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
    },
)
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
    elem, group = 2, 4
    assert M % block_M == 0, "M must be divisible by block_M"
    assert K % block_K == 0, "K must be divisible by block_K"
    assert K % e_factor == 0, "K must be divisible by e_factor"
    assert block_K % e_factor == 0, "block_K must be divisible by e_factor"

    @T.prim_func
    def kernel(
            A: T.Tensor((M, K), dtype),
            A_sp: T.Tensor((M, K // 2), dtype),
            E: T.Tensor((M, e_K), e_dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            if use_cutlass_layout:
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                    }
                )
            T.clear(A_sp_shared)
            T.clear(E_shared)
            # TODO: alloc_var seems buggy here
            non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
            non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
            T.copy(A[bx * block_M, by * block_K], A_shared)
            for tm in T.Parallel(block_M):
                for g_i in range(0, block_K // group):
                    a_k = g_i * group
                    non_zero_cnt[0] = 0
                    for i in range(elem):
                        non_zero_elt_log_idx[i] = 0
                    # Collect the (at most two) non-zero elements of this group
                    for i in range(group):
                        val = A_shared[tm, a_k + i]
                        if val != 0.0:
                            non_zero_elt_log_idx[non_zero_cnt[0]] = i
                            A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
                            non_zero_cnt[0] += 1
                    # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
                    if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
                        non_zero_elt_log_idx[0] = 0
                        non_zero_elt_log_idx[1] = 3
                        A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
                        A_sp_shared[tm, a_k // 2] = 0.0
                    elif non_zero_cnt[0] == 1:
                        A_sp_shared[tm, a_k // 2 + 1] = 0
                        non_zero_elt_log_idx[1] = 3
                    # Pack the two 2-bit indices into this group's metadata nibble
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])

    return kernel
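
# Metadata packing recap (illustration): with e_factor=16 and group=4, each
# int16 metadata element covers 4 groups; group g_i occupies the nibble at
# bits [4*(g_i % 4), 4*(g_i % 4) + 4), with the first kept index in the low
# 2 bits and the second in the high 2 bits, the same layout that
# decode_metadata above reads back.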

def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
    parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
    args = parser.parse_args()

    kernel = matmul_sp_fp16_custom_compress(
        args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout
    )
    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
    if args.use_torch_compressor:
        assert not args.use_cutlass_layout, "torch_compress must be used with the naive layout"
        a_sparse, e = torch_compress(a)
    else:
        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a)
    c = kernel(a_sparse, e, b)
    ref_c = a @ b
    assert not c.isnan().any(), "Kernel output contains NaNs, please report an issue"
    torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
    print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3} s")


if __name__ == "__main__":
    main()

import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench
import torch
arch = nvcc.get_target_compute_version()
DEFAULT_CONFIG = {  # take best config from autotune script
    "4090": {
        T.float: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 64,
            "num_stages": 1,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        T.float16: {
            "block_M": 256,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
    "h20": {
        T.float: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        T.float16: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}  # compute capability -> (K elements per metadata element, metadata dtype)
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization):
    e_factor, e_dtype = ARCH_INFO[arch]

    @T.prim_func
    def gemm_sp_fp16(
            A_sparse: T.Tensor((M, K // 2), T.float16),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), T.float16),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), T.float16)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            T.annotate_layout(
                {
                    E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch),
                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch),
                }
            )
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16

def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
    args = parser.parse_args()

    kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
    a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch)
    c = kernel(a_sparse, e, b)
    ref_c = a @ b
    assert not c.isnan().any(), "Kernel output contains NaNs, please report an issue"
    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
    print(f"Precision check passed. Mean diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3} s")


if __name__ == "__main__":
    main()

import tilelang.testing

import example_custom_compress
import example_gemm_sp


def test_example_custom_compress():
    example_custom_compress.main()


def test_example_gemm_sp():
    example_gemm_sp.main()


if __name__ == "__main__":
    tilelang.testing.main()