Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
......@@ -4,7 +4,8 @@ import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
......@@ -34,18 +35,18 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
......@@ -53,7 +54,7 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
# chunk = 32 if in_dtype == "float16" else 64
# chunk = 32 if in_dtype == T.float16 else 64
chunk = 32
shared_scope = "shared.dyn"
......@@ -104,7 +105,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -112,10 +112,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -123,7 +125,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -133,7 +134,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
......@@ -163,7 +163,7 @@ def ref_program(A, B):
def main(M=4096, N=4096, K=4096):
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
......
......@@ -5,17 +5,7 @@ import argparse
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M,
N,
K,
block_M,
block_N,
block_K,
threads,
num_stages,
dtype="float16",
accum_dtype="float"):
def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
@tilelang.jit(out_idx=[-1])
def matmul_persistent(M,
N,
K,
block_M,
block_N,
block_K,
threads,
num_stages,
dtype="float16",
accum_dtype="float",
use_persistent_primitive=True):
def matmul_persistent(
M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True
):
sm_num = driver.get_num_sms()
m_blocks = T.ceildiv(M, block_M)
n_blocks = T.ceildiv(N, block_N)
......@@ -100,8 +81,7 @@ def matmul_persistent(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
for bx, by in T.Persistent(
[T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
......@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
num_stages = 3
persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
persistent_profiler = persistent_kernel.get_profiler(
tensor_supply_type=tilelang.TensorSupplyType.Randn)
persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Persistent GEMM: All check passed.")
persistent_latency = persistent_profiler.do_bench(warmup=500)
print(f"Persistent GEMM Latency: {persistent_latency} ms")
print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")
non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads,
num_stages)
non_persistent_profiler = non_persistent_kernel.get_profiler(
tensor_supply_type=tilelang.TensorSupplyType.Randn)
non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Non-Persistent GEMM: All check passed.")
non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
......@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=8192, help='M dimension')
parser.add_argument('--N', type=int, default=8192, help='N dimension')
parser.add_argument('--K', type=int, default=8192, help='K dimension')
parser.add_argument("--M", type=int, default=8192, help="M dimension")
parser.add_argument("--N", type=int, default=8192, help="N dimension")
parser.add_argument("--K", type=int, default=8192, help="K dimension")
args = parser.parse_args()
M, N, K = args.M, args.N, args.K
main(M, N, K)
......@@ -3,8 +3,7 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def gemm_schedule(
A: T.Tensor((M, K), dtype),
......
......@@ -17,10 +17,8 @@ def supply_prog(args):
a_param, b_param = args
M, K = a_param.shape
N, _ = b_param.shape
a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
return [a, b]
......@@ -35,10 +33,9 @@ def get_configs():
valid_configs = []
for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
num_stages, num_threads, k_packs,
gemm_types):
valid_configs.append({
for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"block_K": k,
......@@ -46,20 +43,18 @@ def get_configs():
"num_threads": t,
"k_pack": kp,
"gemm_type": gemm_type,
})
}
)
return valid_configs
@tilelang.autotune(
configs=get_configs(),
cache_input_tensors=True,
ref_prog=ref_program,
manual_check_prog=manual_check_prog,
supply_prog=supply_prog)
configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = "float8_e4m3fnuz"
accum_dtype = "float"
dtype = T.float8_e4m3fnuz
accum_dtype = T.float32
@T.prim_func
def gemm_fp8_rs(
......@@ -67,8 +62,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -77,13 +71,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_local)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_local,
B_shared,
C_local,
transpose_B=True,
k_pack=k_pack,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
......@@ -93,8 +81,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
k_pack=k_pack,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
......@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
def test_gemm_fp8(M, N, K):
kernel = fp8_matmul(M, N, K)
a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
......
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
def calc_diff(x, y):
......@@ -12,8 +11,7 @@ def calc_diff(x, y):
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
@T.prim_func
def gemm_fp8(
A: T.Tensor((M, K), dtype),
......@@ -37,12 +35,12 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = map_torch_type(dtype)
torch_dtype = T.dtype(dtype).as_torch()
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
c = kernel(a, b)
......@@ -57,8 +55,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2')
test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2)
if __name__ == "__main__":
......
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
# for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128.
# if block_K < 128, promote after 128/block_K iters.
# if block_K > 128, promote after every iter.
......@@ -55,18 +54,18 @@ def calc_diff(x, y):
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = map_torch_type(dtype)
torch_dtype = T.dtype(dtype).as_torch()
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.rand(M, K, dtype=torch.float16, device='cuda')
a = torch.rand(M, K, dtype=torch.float16, device="cuda")
a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
b = torch.rand(N, K, dtype=torch.float16, device='cuda')
b = torch.rand(N, K, dtype=torch.float16, device="cuda")
b = (100 * (2 * b - 1)).to(dtype=torch_dtype)
c = kernel(a, b)
ref_c = (a.float() @ b.float().T)
ref_c = a.float() @ b.float().T
diff = calc_diff(c, ref_c)
print(f"diff: {diff}")
......@@ -74,8 +73,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2')
test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2)
if __name__ == "__main__":
......
......@@ -5,7 +5,8 @@ from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
......@@ -38,21 +39,26 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"float8_e4m3",
"float8_e5m2",
"int8",
T.float16,
T.float8_e4m3fn,
T.float8_e5m2,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8:
is_float8 = in_dtype in [
T.float8_e4m3fn,
T.float8_e5m2,
T.float8_e4m3fn,
T.float8_e5m2fnuz,
]
if out_dtype == T.int32 or is_float8:
micro_size_k = 32
# This is a debug config
......@@ -60,7 +66,7 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
......@@ -110,7 +116,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -118,10 +123,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -129,7 +136,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -139,7 +145,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......@@ -215,8 +220,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def main():
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
if __name__ == "__main__":
......
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tilelang prim_func computing C = A @ B with a tmem accumulator.

    Args:
        M, N, K: GEMM problem sizes.
        block_M, block_N, block_K: CTA tile sizes.
        trans_A, trans_B: whether A / B are stored transposed in global memory.
        in_dtype, out_dtype, accum_dtype: input, output, and accumulation dtypes.
        num_stages: software-pipeline depth for the K loop.
        threads: threads per CTA.

    Returns:
        The traced ``T.prim_func`` (A, B, C) kernel.
    """
    # Global and shared tile shapes swap their axes when the operand is transposed.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        # One CTA per output tile: bx indexes the N dimension, by the M dimension.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            # Accumulator in Tensor Memory; mbar is used to observe MMA completion.
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm_v2(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    # wg_wait=-1: do not synchronize here; completion is observed
                    # below via the mbarrier parity wait.
                    wg_wait=-1,
                    # Zero the accumulator only on the first K-step.
                    clear_accum=(k == 0),
                )
            # Wait for the final asynchronous MMA before draining the accumulator.
            T.mbarrier_wait_parity(mbar, k % 2)
            # tmem -> registers -> shared -> global.
            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main
def calc_diff(x, y):
    """Return a symmetric relative-difference score between two tensors.

    Computes ``1 - 2*sum(x*y) / sum(x*x + y*y)`` in float64: 0.0 for identical
    tensors, growing as they diverge (2.0 for exact negation).
    """
    xd = x.double()
    yd = y.double()
    cross = (xd * yd).sum()
    energy = (xd * xd).sum() + (yd * yd).sum()
    return 1 - 2 * cross / energy
# Problem size and tiling configuration for the FP8 GEMM sweep below.
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 64, 256, 32
trans_A, trans_B = False, True
num_stages = 2
threads = 256

# Sweep both FP8 input formats against both accumulator precisions.
for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]:
    for tvm_acc_dtype in [T.float16, T.float32]:  # , torch.float16]:
        torch_fp8_dtype = map_torch_type(tvm_fp8_dtype)
        torch_acc_dtype = map_torch_type(tvm_acc_dtype)
        print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}")
        # Output dtype tracks the accumulator dtype in this sweep.
        in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype
        func = matmul(
            M,
            N,
            K,
            block_M,
            block_N,
            block_K,
            trans_A,
            trans_B,
            in_dtype,
            out_dtype,
            accum_dtype,
            num_stages,
            threads,
        )
        jit_kernel = tilelang.compile(
            func,
            out_idx=[2],  # C is allocated by the runtime and returned
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
                tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True,
            },
        )
        # jit_kernel.export_ptx("./dump.ptx")
        # jit_kernel.export_sources("./dump.cu")
        a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
        b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
        c = jit_kernel(a, b)
        # Half-precision reference: upcast the FP8 operands before the matmul.
        ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float()
        c = c.float()
        diff = calc_diff(c, ref_c)
        # assert diff < 1e-3, f"{diff}"
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}")

        profiler = jit_kernel.get_profiler()
        latency = profiler.do_bench()  # milliseconds
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
......@@ -40,19 +40,19 @@ import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((M, K), "bfloat16"),
B: T.Tensor((N, K), "bfloat16"),
C: T.Tensor((M, N), "bfloat16"),
A: T.Tensor((M, K), T.bfloat16),
B: T.Tensor((N, K), T.bfloat16),
C: T.Tensor((M, N), T.bfloat16),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
# 1. Allocate memory buffers
A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory
B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory
C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory
A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory
B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory
C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory
mbar = T.alloc_barrier(1) # mbarrier synchronization primitive
C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage
C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory
C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage
C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory
# 2. Main computation loop
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
......
......@@ -4,8 +4,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
......
......@@ -40,15 +40,7 @@ def matmul(
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
......@@ -62,12 +54,11 @@ def matmul(
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
num_stages = 2
threads = 256
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
func,
out_idx=[2],
......@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
print(jit_kernel.get_kernel_source())
......@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS")
print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import randn_semi_sparse
from tilelang.utils.tensor import torch_assert_close
from triton.testing import do_bench
import torch
# Fixed seed so the random dense/sparse inputs are reproducible across runs.
torch.manual_seed(42)

# Best tile configurations per GPU (taken from the autotune script),
# keyed first by architecture name, then by accumulation dtype.
DEFAULT_CONFIG = {
    "4090": {
        T.float: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 64,
            "num_stages": 1,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        T.float16: {
            "block_M": 256,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
    "h20": {
        T.float: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        T.float16: {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
}

# Per-SM-arch metadata packing: (elements covered per metadata word, metadata dtype).
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(
    M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout
):
    """Build a 2:4 semi-sparse FP16 GEMM kernel: C = A_sparse @ B.

    A is stored compressed (half the columns) alongside metadata tensor E that
    records which 2 of every 4 elements survived. When ``use_cutlass_layout``
    is set, E is annotated with the cutlass metadata layout for arch 8.0.
    """
    # FP16 metadata packing: one int16 word covers 16 elements of K.
    e_factor, e_dtype = (16, T.int16)

    @T.prim_func
    def gemm_sp_fp16_custom_compress(
        A_sparse: T.Tensor((M, K // 2), T.float16),
        E: T.Tensor((M, K // e_factor), e_dtype),
        B: T.Tensor((K, N), T.float16),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), T.float16)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            if use_cutlass_layout:
                # Both the global and shared copies of E must carry the same layout.
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                    }
                )
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            # Rasterization swizzle to improve L2 locality (toggled by config).
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Compressed A advances by block_K // 2, metadata by block_K // e_factor.
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16_custom_compress
def torch_compress(dense):
    """
    A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.

    Returns:
        (sparse, meta): ``sparse`` holds the surviving half of the columns of
        ``dense`` (shape (m, k // 2)); ``meta`` packs one 4-bit group code per
        4 input elements into int16/int32 words (dtype depends on input dtype).
    Raises:
        RuntimeError: for non-2D input, unsupported dtype, or shapes not
        divisible by the required row/column granularity.
    """
    if dense.dim() != 2:
        raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")

    m, k = dense.shape

    # Metadata word width depends on the element dtype.
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    # Number of 4-bit group codes packed into one metadata word.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16")
    else:
        if m % 32 != 0:
            raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}")

    # 2:4 sparsity for sub-float dtypes; 1:2 sparsity for float32.
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.
    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0/idxs1 are the positions of the two surviving elements per group.
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        # Gather the two survivors of each quadruple and interleave them.
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    # Pack the 4-bit group codes into metadata words.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12)
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    return (sparse, meta)
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
    """Unpack int16 sparsity metadata into per-group survivor index pairs.

    Each int16 word holds four 4-bit groups; every group encodes two 2-bit
    indices (idx0 in the low bits, idx1 in the high bits). Returns a tensor
    of shape (rows, 8 * meta.shape[1]) laid out as [idx0, idx1, idx0, idx1, ...].
    """
    assert meta.dtype is torch.int16
    n_groups = 16 // 4  # four 4-bit groups per int16 word
    pairs = []
    for shift in range(0, 4 * n_groups, 4):
        nibble = (meta >> shift) & 0xF
        pairs.append(torch.stack([nibble & 0x3, (nibble >> 2) & 0x3], dim=-1))
    return torch.concat(pairs, dim=-1).view(meta.shape[0], -1)
@tilelang.jit(
    out_idx=[1, 2],
    pass_configs={
        tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
    },
)
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
    """Build a kernel that 2:4-compresses a dense (M, K) matrix on device.

    Outputs the compressed values A_sp (M, K // 2) and the packed metadata
    tensor E (M, K // e_factor); both are returned (out_idx=[1, 2]).
    """
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
    # elem: survivors kept per group; group: elements per sparsity group (2:4).
    elem, group = 2, 4
    assert M % block_M == 0, "M must be divisible by block_M"
    assert K % block_K == 0, "K must be divisible by block_K"
    assert K % e_factor == 0, "K must be divisible by e_factor"
    assert block_K % e_factor == 0, "block_K must be divisible by e_factor"

    @T.prim_func
    def kernel(
        A: T.Tensor((M, K), dtype),
        A_sp: T.Tensor((M, K // 2), dtype),
        E: T.Tensor((M, e_K), e_dtype),
    ):
        # One thread per row of the (block_M, block_K) tile.
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            if use_cutlass_layout:
                # Keep global and shared metadata in the same cutlass layout.
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                    }
                )
            # Metadata is accumulated with |=, so both buffers must start zeroed.
            T.clear(A_sp_shared)
            T.clear(E_shared)
            # TODO: alloc_var seems buggy here
            non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
            non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
            T.copy(A[bx * block_M, by * block_K], A_shared)
            for tm in T.Parallel(block_M):
                for g_i in range(0, block_K // group):
                    a_k = g_i * group
                    # Scan one group of 4 elements, recording up to 2 non-zeros.
                    non_zero_cnt[0] = 0
                    for i in range(elem):
                        non_zero_elt_log_idx[i] = 0
                    for i in range(group):
                        val = A_shared[tm, a_k + i]
                        if val != 0.0:
                            non_zero_elt_log_idx[non_zero_cnt[0]] = i
                            A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
                            non_zero_cnt[0] += 1
                    # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
                    if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
                        # Single survivor at position 3 must occupy slot 1, not slot 0.
                        non_zero_elt_log_idx[0] = 0
                        non_zero_elt_log_idx[1] = 3
                        A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
                        A_sp_shared[tm, a_k // 2] = 0.0
                    elif non_zero_cnt[0] == 1:
                        # Pad the second slot with a zero and point its index at 3.
                        A_sp_shared[tm, a_k // 2 + 1] = 0
                        non_zero_elt_log_idx[1] = 3
                    # Pack the two 2-bit indices of this group into the metadata word.
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])

    return kernel
def main():
    """Validate and benchmark the custom-compress sparse GEMM pipeline.

    Parses CLI flags, builds the sparse matmul kernel from the tuned default
    config, compresses a random semi-sparse operand (via the torch compressor
    or the TileLang compress kernel), checks the sparse result against a dense
    reference, then reports latency and TFLOPS for both paths.
    """
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
    parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
    # NOTE(review): type=str with non-str choices assumes T.float / T.float16
    # compare equal to their string spellings — confirm T.dtype is str-like.
    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
    args = parser.parse_args()

    kernel = matmul_sp_fp16_custom_compress(
        args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout
    )

    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)

    if args.use_torch_compressor:
        # torch's compressor only produces the naive metadata layout.
        assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
        a_sparse, e = torch_compress(a)
    else:
        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a)

    c = kernel(a_sparse, e, b)
    ref_c = a @ b
    # The NaN check runs on the kernel output `c` (the previous message wrongly
    # blamed the reference result).
    assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
    torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
    print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}")

    # do_bench reports milliseconds: /1e3 converts to seconds for display,
    # and flops / ms / 1e9 yields TFLOPS.
    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")


if __name__ == "__main__":
    main()
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_metadata_layout
from tilelang.layout import make_cutlass_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench
......@@ -14,86 +12,79 @@ import torch
arch = nvcc.get_target_compute_version()
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
default_config = { # take best config from autotune script
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
T.float: {
"block_M": 128,
"block_N": 64,
"block_K": 64,
"num_stages": 1,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
T.float16: {
"block_M": 256,
"block_N": 128,
"block_K": 64,
"num_stages": 2,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
T.float: {
"block_M": 128,
"block_N": 64,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
T.float16: {
"block_M": 128,
"block_N": 64,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
enable_rasterization):
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization):
e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func
def gemm_sp_fp16(
A_sparse: T.Tensor((M, K // 2), 'float16'),
A_sparse: T.Tensor((M, K // 2), T.float16),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
B: T.Tensor((K, N), T.float16),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
B_shared = T.alloc_shared((block_K, block_N), T.float16)
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
backend="cutlass",
block_k=block_K,
arch=arch),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch),
}
)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
......@@ -111,25 +102,15 @@ def main():
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
**default_config[args.cfg][args.accum_dtype])
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
a_sparse, e = compress(
a,
transposed=False,
block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch)
c = kernel(a_sparse, e, b)
ref_c = a @ b
......@@ -144,8 +125,8 @@ def main():
total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")
if __name__ == "__main__":
......
import tilelang.testing
import example_dynamic
import example_custom_compress
import example_gemm_sp
def test_example_dynamic():
    # Smoke test: run the dynamic-shape example end-to-end on a 1024^3 problem.
    example_dynamic.main(M=1024, N=1024, K=1024)
def test_example_custom_compress():
    # Smoke test: run the custom-compress sparse GEMM example with its defaults.
    example_custom_compress.main()
def test_example_gemm_sp():
    # Smoke test: run the sparse GEMM example with its defaults.
    example_gemm_sp.main()
if __name__ == "__main__":
......
......@@ -3,17 +3,7 @@ import tilelang.language as T
@tilelang.jit
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32):
splitK = K // split_k
@T.prim_func
......@@ -22,8 +12,7 @@ def matmul(M,
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
......
......@@ -3,17 +3,7 @@ import tilelang.language as T
@tilelang.jit
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32):
splitK = K // split_k
@T.prim_func
......@@ -22,8 +12,7 @@ def matmul(M,
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
......
......@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
# Two-tile SK + DP
streamk_tiles = total_tiles % streamk_programs
if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1)
if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1)
streamk_tiles += streamk_programs
blocking_tiles = total_tiles - streamk_tiles
......@@ -87,8 +87,8 @@ def tl_matmul_streamk(
C: T.Tensor,
C_local: T.LocalBuffer,
):
start_iter = T.alloc_fragment((1,), "int32", "local")
end_iter = T.alloc_fragment((1,), "int32", "local")
start_iter = T.alloc_fragment((1,), T.int32, "local")
end_iter = T.alloc_fragment((1,), T.int32, "local")
start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles)
last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles)
......@@ -135,7 +135,6 @@ def tl_matmul_streamk(
C: T.Tensor,
C_local: T.LocalBuffer,
):
for p in T.serial(sm_patition_factor):
tile_id = pid + streamk_tiles + p * total_sm
pid_m = tile_id // T.ceildiv(N, block_N)
......@@ -155,7 +154,6 @@ def tl_matmul_streamk(
C: T.Tensor((M, N), dtypeC),
):
with T.Kernel(streamk_programs, threads=threads) as pid:
A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
......@@ -181,9 +179,9 @@ def main():
BLOCK_SIZE_K,
False,
True,
"float16",
"float16",
"float32",
T.float16,
T.float16,
T.float32,
2,
64,
)
......
......@@ -17,10 +17,9 @@ def naive_gemv(
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
@T.prim_func
def main(
A: T.Tensor((K,), dtype),
......@@ -38,8 +37,7 @@ def naive_gemv(
A_shared[tk] = A[bk * BLOCK_K + tk]
B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
for tk in T.serial(BLOCK_K):
C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn,
tk].astype(accum_dtype)
C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype)
C[bn * BLOCK_N + tn] = C_reg[0]
return main
......@@ -51,10 +49,9 @@ def naive_splitk_gemv(
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
@T.prim_func
def main(
A: T.Tensor((K,), dtype),
......@@ -88,8 +85,8 @@ def splitk_gemv(
BLOCK_N: int,
BLOCK_K: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
TILE_K = T.ceildiv(BLOCK_K, reduce_threads)
......@@ -127,8 +124,8 @@ def splitk_gemv_vectorized(
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
......@@ -168,8 +165,8 @@ def splitk_gemv_vectorized_tvm(
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
......@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
C_reduced[0],
tk,
dtype="handle",
))
)
)
C[bn * BLOCK_N + tn] = C_reduced[0]
......@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm(
def get_block_template_configs():
iter_params = dict(
block_M=[2, 4, 8, 32, 64, 128],
block_N=[2, 4, 8, 32, 64, 128],
num_stages=[0, 1, 2, 3, 4],
threads=[32, 64, 128, 256])
block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256]
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
......@@ -237,18 +233,11 @@ def get_block_template_configs():
},
out_idx=[2],
)
def gemv_alloc_reducer(M,
N,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
accum_dtype: str = "float"):
def gemv_alloc_reducer(
M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float
):
@T.prim_func
def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M,
dtype)): # type: ignore
def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore
with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
T.clear(o_reducer)
......@@ -287,8 +276,8 @@ def get_autotuned_kernel(
BLOCK_N=None,
reduce_threads=None,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
......@@ -327,17 +316,18 @@ def get_autotuned_kernel(
C_reduced[0],
tk,
dtype="handle",
))
)
)
C[bn * BLOCK_N + tn] = C_reduced[0]
return main
def check_correctness_and_bench(kernel, N, K, bench_ref=True):
def check_correctness_and_bench(kernel, N, K, do_bench=True):
profiler = kernel.get_profiler()
profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2)
if bench_ref:
if do_bench:
latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50)
print(f"Torch Latency: {latency} ms")
latency = profiler.do_bench(kernel, warmup=50)
......@@ -350,16 +340,16 @@ def main(do_bench: bool = True):
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
args, _ = parser.parse_known_args()
N, K = args.n, args.k
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K)
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K)
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K)
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench)
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
print("Test passed!")
if not do_bench:
if do_bench:
best_result = get_autotuned_kernel(N, K)
best_config = best_result.config
kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
......
import tilelang.testing
import example_gemv
......@@ -8,4 +6,4 @@ def test_example_gemv():
if __name__ == "__main__":
tilelang.testing.main()
test_example_gemv()
......@@ -5,67 +5,45 @@ import tilelang
import tilelang.language as T
@tilelang.jit(
out_idx=[2], pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def grouped_gemm_fwd(batch_sum,
batch_count,
K,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
b (torch.Tensor): Input tensor of shape (G, K, N).
"""
accum_dtype = "float32"
accum_dtype = T.float32
@T.prim_func
def kernel(
A: T.Tensor([batch_sum, K], dtype), # type: ignore
B: T.Tensor([batch_count, K, N], dtype), # type: ignore
C: T.Tensor([batch_sum, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore
batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore
):
with T.Kernel(
T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared([block_M, block_K], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
cur_batch_idx = T.alloc_local([1], "int32")
cur_batch_size = T.alloc_local([1], "int32")
cur_batch_idx = T.alloc_local([1], T.int32)
cur_batch_size = T.alloc_local([1], T.int32)
m_start_padded = bx * block_M
for i in range(batch_count):
in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i])
in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])
cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[
cur_batch_idx[0]]
actual_rows = T.max(
0,
T.min(block_M,
cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared)
T.copy(
B[cur_batch_idx[0], k * block_K:(k + 1) * block_K,
by * block_N:(by + 1) * block_N], B_shared)
T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
for i, j in T.Parallel(block_M, block_N):
......@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
class _GroupedGEMM(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b, batch_sizes):
block_M = 64
......@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i])
for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
math.ceil((batch_sizes[i] + 1) / padding_M) *
padding_M)
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M)
batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32)
batch_padded_offsets = torch.tensor(
batch_padded_offsets_list, device=a.device, dtype=torch.int32)
batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32)
kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K,
num_stages, threads)
kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads)
o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets)
ctx.save_for_backward(a, b, batch_sizes, batch_offsets)
......@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
return x
A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)]
kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K,
num_stages, threads)
kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads)
dB = kernel(A, grad_output, batch_sizes, batch_offsets)
return None, dB, None
......@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
math.ceil((batch_sizes_list[i] + 1) / padding_M) *
padding_M)
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M)
A = torch.randn(batch_sum, K, device=device, dtype=dtype)
B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
C = torch.empty(batch_sum, M, device=device, dtype=dtype)
......@@ -187,40 +157,24 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets
@tilelang.jit(
out_idx=[2], pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def grouped_gemm_bwd(batch_sum,
batch_count,
M,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
b (torch.Tensor): Input tensor of shape (G, K, N).
"""
accum_dtype = "float32"
accum_dtype = T.float32
@T.prim_func
def kernel(
A: T.Tensor([batch_sum, M], dtype), # type: ignore
B: T.Tensor([batch_sum, N], dtype), # type: ignore
C: T.Tensor([batch_count, M, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore
batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore
):
with T.Kernel(
T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count,
threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared([block_K, block_M], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
......@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages):
for i, j in T.Parallel(block_K, block_M):
A_shared[i, j] = T.if_then_else(
i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i,
bx * block_M + j], 0)
A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0)
for i, j in T.Parallel(block_K, block_N):
B_shared[i, j] = T.if_then_else(
i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i,
by * block_N + j], 0)
B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0)
T.gemm(A_shared, B_shared, C_local, transpose_A=True)
T.copy(C_local, C[bz, bx * block_M, by * block_N])
......@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
return kernel
def run_tilelang_grouped_gemm(batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages=2,
threads=128,
profile=False):
def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False):
padding_M = block_M
device = torch.device("cuda")
dtype = torch.float16
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
batch_sizes_list, K, M, False, padding_M, device, dtype)
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype)
A.requires_grad_(False)
B.requires_grad_(True)
......@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
O.backward(dO, retain_graph=True)
dB, B.grad = B.grad.clone(), None
if (
torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \
torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2)
):
if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2):
print("✅ Tilelang and Torch match")
else:
print("❌ Tilelang and Torch mismatch")
......@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes')
parser.add_argument('--K', type=int, default=8192, help='reduce dim')
parser.add_argument('--M', type=int, default=8192, help='output dim')
parser.add_argument('--trans_b', action="store_true", help="transpose B")
parser.add_argument('--profile', action="store_true", help="profile")
parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
parser.add_argument("--K", type=int, default=8192, help="reduce dim")
parser.add_argument("--M", type=int, default=8192, help="output dim")
parser.add_argument("--trans_b", action="store_true", help="transpose B")
parser.add_argument("--profile", action="store_true", help="profile")
args = parser.parse_args()
batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
......@@ -301,14 +236,4 @@ if __name__ == "__main__":
num_stages = 2
threads = 256
run_tilelang_grouped_gemm(
batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages,
threads,
profile=args.profile)
run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment