Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
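
The diff below is a pure formatting pass over the example files: yapf's line wrapping is replaced by ruff's formatter, so only indentation, quoting, and line breaks change. As a hedged illustration only (not part of this commit), a switch like this is usually applied by running ruff over the tree; the sketch below assumes ruff is installed and formats the current directory.

# Hypothetical helper, not taken from this commit: apply ruff's formatter and
# its auto-fixable lint rules to a source tree, replacing a previous yapf pass.
import subprocess


def format_with_ruff(path: str = ".") -> None:
    # "ruff format" rewrites files in place with ruff's code formatter.
    subprocess.run(["ruff", "format", path], check=True)
    # "ruff check --fix" then applies the auto-fixable lint rules.
    subprocess.run(["ruff", "check", "--fix", path], check=True)


if __name__ == "__main__":
    format_with_ruff()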
@@ -4,12 +4,11 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
                num_stages,
                thread_num,
                enable_rasterization,
            )
        )
        configs = [
            {
@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
            }
            for c in _configs
        ]
    return configs


def get_best_config(M, N, K, with_roller=False):
    def kernel(
        block_M=None,
        block_N=None,
@@ -120,12 +121,11 @@ def get_best_config(M, N, K, with_roller=False):
        @T.prim_func
        def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
        return main

    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
        .set_compile_args(
            out_idx=[-1],
            target="auto",
        )
        .set_profile_args(
            supply_type=tl.TensorSupplyType.Integer,
            ref_prog=ref_program,
            skip_check=False,
        )
    )
    return autotuner.run(warmup=3, rep=20)
@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict:
    sm_version = sm_major * 10 + sm_minor
    print(f"CUDA device capability: {sm_version}")
    if sm_version in {80}:
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
    elif sm_version in {90}:
        return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
    else:
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}


@tl.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm_autotune(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -236,11 +207,7 @@ def matmul(M,
    return gemm_autotune


def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False):
    use_autotune = True
    if use_autotune:
        result = get_best_config(M, N, K, with_roller)
@@ -266,15 +233,7 @@ if __name__ == "__main__":
    parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
    parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
    parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
    args = parser.parse_args()
    main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
@@ -4,7 +4,8 @@ import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
@@ -99,12 +100,11 @@ def tl_matmul(
    @T.prim_func
    def gemm_intrinsics(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -112,10 +112,12 @@ def tl_matmul(
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
@@ -123,7 +125,6 @@ def tl_matmul(
            T.clear(C_local)
            for ko in T.Pipelined((K // block_K), num_stages=stage):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -133,7 +134,6 @@ def tl_matmul(
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(A_local, A_shared, ki)
...
@@ -5,22 +5,12 @@ import argparse
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
@tilelang.jit(out_idx=[-1])
def matmul_persistent(
    M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True
):
    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
@@ -63,9 +44,9 @@ def matmul_persistent(M,
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(sm_num, threads=threads) as (block_id):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -90,9 +71,9 @@ def matmul_persistent(M,
    @T.prim_func
    def main_persistent_primitive(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(sm_num, threads=threads) as (block_id):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -100,8 +81,7 @@ def matmul_persistent(M,
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[bx * block_M, k * block_K], A_shared)
@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
    num_stages = 3
    persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Persistent GEMM: All check passed.")
    persistent_latency = persistent_profiler.do_bench(warmup=500)
    print(f"Persistent GEMM Latency: {persistent_latency} ms")
    print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")

    non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Non-Persistent GEMM: All check passed.")
    non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", type=int, default=8192, help="M dimension")
    parser.add_argument("--N", type=int, default=8192, help="N dimension")
    parser.add_argument("--K", type=int, default=8192, help="K dimension")
    args = parser.parse_args()
    M, N, K = args.M, args.N, args.K
    main(M, N, K)
@@ -4,12 +4,11 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm_schedule(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -17,10 +17,8 @@ def supply_prog(args):
    a_param, b_param = args
    M, K = a_param.shape
    N, _ = b_param.shape
    a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    return [a, b]
@@ -35,27 +33,24 @@ def get_configs():
    valid_configs = []
    for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types):
        valid_configs.append(
            {
                "block_M": m,
                "block_N": n,
                "block_K": k,
                "num_stages": stages,
                "num_threads": t,
                "k_pack": kp,
                "gemm_type": gemm_type,
            }
        )
    return valid_configs


@tilelang.autotune(
    configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
    dtype = "float8_e4m3fnuz"
@@ -63,12 +58,11 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
    @T.prim_func
    def gemm_fp8_rs(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -77,24 +71,17 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])

    @T.prim_func
    def gemm_fp8_ss(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])
@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
def test_gemm_fp8(M, N, K):
    kernel = fp8_matmul(M, N, K)
    a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
    c = kernel(a, b)
    ref_c = ref_program(a, b)
    torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
...
@@ -13,12 +13,11 @@ def calc_diff(x, y):
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    @T.prim_func
    def gemm_fp8(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype):
    kernel = matmul(M, N, K, 128, 128, 64, dtype)

    a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
    b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)

    c = kernel(a, b)
@@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
    test_gemm_fp8(1024, 1024, 1024, "float8_e4m3")
    test_gemm_fp8(1024, 1024, 1024, "float8_e5m2")


if __name__ == "__main__":
...
...@@ -13,9 +13,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): ...@@ -13,9 +13,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
@T.prim_func @T.prim_func
def gemm_fp8_2xAcc( def gemm_fp8_2xAcc(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype), C: T.Tensor((M, N), accum_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype): ...@@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype):
kernel = matmul(M, N, K, 128, 128, 64, dtype) kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.rand(M, K, dtype=torch.float16, device='cuda') a = torch.rand(M, K, dtype=torch.float16, device="cuda")
a = (100 * (2 * a - 1)).to(dtype=torch_dtype) a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
b = torch.rand(N, K, dtype=torch.float16, device='cuda') b = torch.rand(N, K, dtype=torch.float16, device="cuda")
b = (100 * (2 * b - 1)).to(dtype=torch_dtype) b = (100 * (2 * b - 1)).to(dtype=torch_dtype)
c = kernel(a, b) c = kernel(a, b)
ref_c = (a.float() @ b.float().T) ref_c = a.float() @ b.float().T
diff = calc_diff(c, ref_c) diff = calc_diff(c, ref_c)
print(f"diff: {diff}") print(f"diff: {diff}")
...@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype): ...@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main(): def main():
test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3') test_gemm_fp8(1024, 1024, 8192, "float8_e4m3")
test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2') test_gemm_fp8(1024, 1024, 8192, "float8_e5m2")
if __name__ == "__main__": if __name__ == "__main__":
......
@@ -5,7 +5,8 @@ from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
@@ -110,12 +111,11 @@ def tl_matmul(
    @T.prim_func
    def gemm_fp8_intrinsic(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -123,10 +123,12 @@ def tl_matmul(
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
@@ -134,7 +136,6 @@ def tl_matmul(
            T.clear(C_local)
            for ko in T.Pipelined((K // block_K), num_stages=stage):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -144,7 +145,6 @@ def tl_matmul(
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
...
@@ -26,9 +26,9 @@ def matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
    profiler = jit_kernel.get_profiler()
    latency = profiler.do_bench()
    print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
    print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
@@ -5,12 +5,11 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
...
@@ -25,9 +25,9 @@ def matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -40,15 +40,7 @@ def matmul(
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)
            T.copy(C_tmem, C_local)
@@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 2
threads = 256

func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
print(jit_kernel.get_kernel_source())
@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
@@ -17,77 +17,76 @@ torch.manual_seed(42)
DEFAULT_CONFIG = {  # take best config from autotune script
    "4090": {
        "float": {
            "block_M": 128,
            "block_N": 64,
            "block_K": 64,
            "num_stages": 1,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        "float16": {
            "block_M": 256,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
    "h20": {
        "float": {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
        "float16": {
            "block_M": 128,
            "block_N": 64,
            "block_K": 128,
            "num_stages": 3,
            "thread_num": 128,
            "policy": T.GemmWarpPolicy.Square,
            "enable_rasterization": True,
        },
    },
}

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(
    M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout
):
    e_factor, e_dtype = (16, "int16")

    @T.prim_func
    def gemm_sp_fp16_custom_compress(
        A_sparse: T.Tensor((M, K // 2), "float16"),
        E: T.Tensor((M, K // e_factor), e_dtype),
        B: T.Tensor((K, N), "float16"),
        C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            if use_cutlass_layout:
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
                    }
                )
            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
@@ -108,8 +107,7 @@ def torch_compress(dense):
    A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.
    """
    if dense.dim() != 2:
        raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")

    m, k = dense.shape
@@ -131,9 +129,7 @@ def torch_compress(dense):
    if m % 32 != 0:
        raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}")

    if dense.dtype != torch.float:
        ksparse = 4
@@ -194,19 +190,13 @@ def torch_compress(dense):
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12)
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
@@ -216,7 +206,8 @@ def torch_compress(dense):
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    return (sparse, meta)
@@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
@tilelang.jit(
    out_idx=[1, 2],
    pass_configs={
        tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
    },
)
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
@@ -249,23 +242,21 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
    @T.prim_func
    def kernel(
        A: T.Tensor((M, K), dtype),
        A_sp: T.Tensor((M, K // 2), dtype),
        E: T.Tensor((M, e_K), e_dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            if use_cutlass_layout:
                T.annotate_layout(
                    {
                        E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
                    }
                )
            T.clear(A_sp_shared)
            T.clear(E_shared)
            # TODO: alloc_var seems buggy here
@@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
                        non_zero_elt_log_idx[1] = 3
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
@@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
    parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
    args = parser.parse_args()

    kernel = matmul_sp_fp16_custom_compress(
        args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout
    )

    a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
    b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)

    if args.use_torch_compressor:
        assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
        a_sparse, e = torch_compress(a)
    else:
        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a)

    c = kernel(a_sparse, e, b)
@@ -346,9 +322,7 @@ def main():
    assert not c.isnan().any(), "Reference result contains NaNs, please report an issue"

    torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
    print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)
@@ -356,8 +330,8 @@ def main():
    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")


if __name__ == "__main__":
...
...@@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version() ...@@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version()
DEFAULT_CONFIG = { # take best config from autotune script DEFAULT_CONFIG = { # take best config from autotune script
"4090": { "4090": {
'float': { "float": {
'block_M': 128, "block_M": 128,
'block_N': 64, "block_N": 64,
'block_K': 64, "block_K": 64,
'num_stages': 1, "num_stages": 1,
'thread_num': 128, "thread_num": 128,
'policy': T.GemmWarpPolicy.Square, "policy": T.GemmWarpPolicy.Square,
'enable_rasterization': True "enable_rasterization": True,
},
"float16": {
"block_M": 256,
"block_N": 128,
"block_K": 64,
"num_stages": 2,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
}, },
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}, },
"h20": { "h20": {
'float': { "float": {
'block_M': 128, "block_M": 128,
'block_N': 64, "block_N": 64,
'block_K': 128, "block_K": 128,
'num_stages': 3, "num_stages": 3,
'thread_num': 128, "thread_num": 128,
'policy': T.GemmWarpPolicy.Square, "policy": T.GemmWarpPolicy.Square,
'enable_rasterization': True "enable_rasterization": True,
},
"float16": {
"block_M": 128,
"block_N": 64,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
}, },
'float16': { },
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
} }
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization):
enable_rasterization):
e_factor, e_dtype = ARCH_INFO[arch] e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func @T.prim_func
def gemm_sp_fp16( def gemm_sp_fp16(
A_sparse: T.Tensor((M, K // 2), 'float16'), A_sparse: T.Tensor((M, K // 2), "float16"),
E: T.Tensor((M, K // e_factor), e_dtype), E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'), B: T.Tensor((K, N), "float16"),
C: T.Tensor((M, N), accum_dtype), C: T.Tensor((M, N), accum_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16') A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16') B_shared = T.alloc_shared((block_K, block_N), "float16")
C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local) T.clear(C_local)
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization) T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({ T.annotate_layout(
E: {
make_cutlass_metadata_layout( E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch),
E, mma_dtype="float16", block_k=block_K, arch=arch), E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
E_shared: }
make_cutlass_metadata_layout( )
E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared) T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
...@@ -107,25 +104,15 @@ def main(): ...@@ -107,25 +104,15 @@ def main():
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument( parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args() args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
**DEFAULT_CONFIG[args.cfg][args.accum_dtype])
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half) a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half) b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
a_sparse, e = compress( a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch)
a,
transposed=False,
block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
c = kernel(a_sparse, e, b) c = kernel(a_sparse, e, b)
ref_c = a @ b ref_c = a @ b
...@@ -140,8 +127,8 @@ def main(): ...@@ -140,8 +127,8 @@ def main():
total_flops = 2 * args.m * args.n * args.k total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9 tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s") print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s") print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")
if __name__ == "__main__": if __name__ == "__main__":
......
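For readers who have not seen 2:4 semi-structured sparsity before: `randn_semi_sparse` produces an operand where every group of four consecutive values along K holds at most two nonzeros, and `compress` packs it into the K/2-wide `A_sparse` plus the metadata tensor `E` of per-group indices that the sparse tensor cores consume. A toy sketch of the idea (conceptual only; the real metadata layout is the CUTLASS one set up by `make_cutlass_metadata_layout`):

import torch

# One row with a valid 2:4 pattern: exactly two nonzeros in every group of four values.
row = torch.tensor([0.0, 1.5, 0.0, -2.0, 3.0, 0.0, 0.5, 0.0], dtype=torch.half)
groups = row.view(-1, 4)                                                # (K/4, 4)
values = torch.stack([g[g != 0] for g in groups])                      # (K/4, 2): the K/2 "A_sparse" entries
indices = torch.stack([(g != 0).nonzero().flatten() for g in groups])  # kept positions, i.e. the "E" metadata
print(values)   # [[ 1.5, -2.0], [ 3.0,  0.5]]
print(indices)  # [[1, 3], [0, 2]]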
...@@ -3,27 +3,16 @@ import tilelang.language as T ...@@ -3,27 +3,16 @@ import tilelang.language as T
@tilelang.jit @tilelang.jit
def matmul(M, def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
splitK = K // split_k splitK = K // split_k
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype)
......
...@@ -3,27 +3,16 @@ import tilelang.language as T ...@@ -3,27 +3,16 @@ import tilelang.language as T
@tilelang.jit @tilelang.jit
def matmul(M, def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
splitK = K // split_k splitK = K // split_k
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype)
......
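The split-K kernels above give each `bz` in the third grid dimension one K-slice of length `splitK = K // split_k`; the per-slice partial products then have to be summed into C (that reduction sits outside the hunks shown here). A hedged PyTorch model of the same accumulation, with illustrative sizes:

import torch

M, N, K, split_k = 256, 256, 1024, 4
A = torch.randn(M, K, dtype=torch.half)
B = torch.randn(N, K, dtype=torch.half)      # B is (N, K), matching the kernel signature
C = torch.zeros(M, N, dtype=torch.float32)   # out_dtype="float32"
ks = K // split_k                            # splitK
for bz in range(split_k):                    # one K-slice per bz
    a = A[:, bz * ks:(bz + 1) * ks].float()
    b = B[:, bz * ks:(bz + 1) * ks].float()
    C += a @ b.t()                           # partial product for this slice
assert torch.allclose(C, A.float() @ B.float().t(), atol=1e-2)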
...@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n ...@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
# Two-tile SK + DP # Two-tile SK + DP
streamk_tiles = total_tiles % streamk_programs streamk_tiles = total_tiles % streamk_programs
if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1) if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1)
streamk_tiles += streamk_programs streamk_tiles += streamk_programs
blocking_tiles = total_tiles - streamk_tiles blocking_tiles = total_tiles - streamk_tiles
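Worked numbers for the two-tile SK + DP split above, assuming 108 persistent programs (one per SM on an A100-class part) and a 16 x 16 grid of output tiles:

streamk_programs = 108                                   # illustrative: one program per SM
num_block_m = num_block_n = 16
total_tiles = num_block_m * num_block_n                  # 256
streamk_tiles = total_tiles % streamk_programs           # 256 % 108 = 40
if total_tiles - streamk_tiles > streamk_programs:       # 216 > 108, so take one extra full wave
    streamk_tiles += streamk_programs                    # 40 + 108 = 148 tiles handled stream-K style
blocking_tiles = total_tiles - streamk_tiles             # 108 tiles left as one data-parallel wave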
...@@ -135,7 +135,6 @@ def tl_matmul_streamk( ...@@ -135,7 +135,6 @@ def tl_matmul_streamk(
C: T.Tensor, C: T.Tensor,
C_local: T.LocalBuffer, C_local: T.LocalBuffer,
): ):
for p in T.serial(sm_patition_factor): for p in T.serial(sm_patition_factor):
tile_id = pid + streamk_tiles + p * total_sm tile_id = pid + streamk_tiles + p * total_sm
pid_m = tile_id // T.ceildiv(N, block_N) pid_m = tile_id // T.ceildiv(N, block_N)
...@@ -150,12 +149,11 @@ def tl_matmul_streamk( ...@@ -150,12 +149,11 @@ def tl_matmul_streamk(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, dtypeAB), A: T.Tensor(A_shape, dtypeAB),
B: T.Tensor(B_shape, dtypeAB), B: T.Tensor(B_shape, dtypeAB),
C: T.Tensor((M, N), dtypeC), C: T.Tensor((M, N), dtypeC),
): ):
with T.Kernel(streamk_programs, threads=threads) as pid: with T.Kernel(streamk_programs, threads=threads) as pid:
A_shared = T.alloc_shared(A_shared_shape, dtypeAB) A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared = T.alloc_shared(B_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
......
...@@ -20,12 +20,11 @@ def naive_gemv( ...@@ -20,12 +20,11 @@ def naive_gemv(
dtype: str = "float16", dtype: str = "float16",
accum_dtype: str = "float", accum_dtype: str = "float",
): ):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((K,), dtype), A: T.Tensor((K,), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn:
tn = T.get_thread_binding(0) # tn = threadIdx.x tn = T.get_thread_binding(0) # tn = threadIdx.x
...@@ -38,8 +37,7 @@ def naive_gemv( ...@@ -38,8 +37,7 @@ def naive_gemv(
A_shared[tk] = A[bk * BLOCK_K + tk] A_shared[tk] = A[bk * BLOCK_K + tk]
B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
for tk in T.serial(BLOCK_K): for tk in T.serial(BLOCK_K):
C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype)
tk].astype(accum_dtype)
C[bn * BLOCK_N + tn] = C_reg[0] C[bn * BLOCK_N + tn] = C_reg[0]
return main return main
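With `A: (K,)`, `B: (N, K)` and `C: (N,)` as declared above, `naive_gemv` and the split-K variants that share this signature compute C[n] = sum_k B[n, k] * A[k], accumulating in `accum_dtype`. A hedged reference for correctness checks (sizes illustrative):

import torch

N, K = 2048, 1024
A = torch.randn(K, dtype=torch.half)
B = torch.randn(N, K, dtype=torch.half)
ref_C = (B.float() @ A.float()).half()   # C[n] = sum_k B[n, k] * A[k], accumulated in float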
...@@ -54,12 +52,11 @@ def naive_splitk_gemv( ...@@ -54,12 +52,11 @@ def naive_splitk_gemv(
dtype: str = "float16", dtype: str = "float16",
accum_dtype: str = "float", accum_dtype: str = "float",
): ):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((K,), dtype), A: T.Tensor((K,), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
tn = T.get_thread_binding(0) tn = T.get_thread_binding(0)
...@@ -95,9 +92,9 @@ def splitk_gemv( ...@@ -95,9 +92,9 @@ def splitk_gemv(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((K,), dtype), A: T.Tensor((K,), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0) tn = T.get_thread_binding(0)
...@@ -136,9 +133,9 @@ def splitk_gemv_vectorized( ...@@ -136,9 +133,9 @@ def splitk_gemv_vectorized(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((K,), dtype), A: T.Tensor((K,), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0) tn = T.get_thread_binding(0)
...@@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm( ...@@ -177,9 +174,9 @@ def splitk_gemv_vectorized_tvm(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((K,), dtype), A: T.Tensor((K,), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0) tn = T.get_thread_binding(0)
...@@ -197,9 +194,9 @@ def splitk_gemv_vectorized_tvm( ...@@ -197,9 +194,9 @@ def splitk_gemv_vectorized_tvm(
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
C_reduced = T.alloc_local((1,), accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype)
with T.attr( with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope", "reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"),
): ):
T.evaluate( T.evaluate(
T.tvm_thread_allreduce( T.tvm_thread_allreduce(
...@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm( ...@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
C_reduced[0], C_reduced[0],
tk, tk,
dtype="handle", dtype="handle",
)) )
)
C[bn * BLOCK_N + tn] = C_reduced[0] C[bn * BLOCK_N + tn] = C_reduced[0]
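The `T.attr(... "reduce_scope" ...)` plus `T.tvm_thread_allreduce` pattern above is TVM's cross-thread reduction: each of the `reduce_threads` lanes bound to `tk` contributes its private `C_accum[0]`, the `+` comm_reducer sums the contributions, and every lane reads the same result back from `C_reduced[0]`. A NumPy stand-in for what the intrinsic computes (not the real lowering):

import numpy as np

reduce_threads = 32
partials = np.random.rand(reduce_threads).astype(np.float32)  # each lane's C_accum[0]
c_reduced = np.float32(partials.sum())                        # the value every lane sees in C_reduced[0]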
...@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm( ...@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm(
def get_block_template_configs(): def get_block_template_configs():
iter_params = dict( iter_params = dict(
block_M=[2, 4, 8, 32, 64, 128], block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256]
block_N=[2, 4, 8, 32, 64, 128], )
num_stages=[0, 1, 2, 3, 4],
threads=[32, 64, 128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
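The helper above enumerates the full Cartesian product of the parameter lists, so the block-template grid has 6 x 6 x 5 x 4 = 720 candidate configs. The same pattern on a smaller, illustrative grid:

import itertools

iter_params = dict(block_M=[32, 64], block_N=[32, 64], num_stages=[1, 2], threads=[128, 256])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
assert len(configs) == 2 * 2 * 2 * 2   # 16 combinations
print(configs[0])                      # {'block_M': 32, 'block_N': 32, 'num_stages': 1, 'threads': 128}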
...@@ -237,18 +233,9 @@ def get_block_template_configs(): ...@@ -237,18 +233,9 @@ def get_block_template_configs():
}, },
out_idx=[2], out_idx=[2],
) )
def gemv_alloc_reducer(M, def gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"):
N,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
accum_dtype: str = "float"):
@T.prim_func @T.prim_func
def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore
dtype)): # type: ignore
with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all") o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
T.clear(o_reducer) T.clear(o_reducer)
...@@ -295,9 +282,9 @@ def get_autotuned_kernel( ...@@ -295,9 +282,9 @@ def get_autotuned_kernel(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((K,), dtype), A: T.Tensor((K,), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0) tn = T.get_thread_binding(0)
...@@ -315,9 +302,9 @@ def get_autotuned_kernel( ...@@ -315,9 +302,9 @@ def get_autotuned_kernel(
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
C_reduced = T.alloc_local((1,), accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype)
with T.attr( with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope", "reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"),
): ):
T.evaluate( T.evaluate(
T.tvm_thread_allreduce( T.tvm_thread_allreduce(
...@@ -327,7 +314,8 @@ def get_autotuned_kernel( ...@@ -327,7 +314,8 @@ def get_autotuned_kernel(
C_reduced[0], C_reduced[0],
tk, tk,
dtype="handle", dtype="handle",
)) )
)
C[bn * BLOCK_N + tn] = C_reduced[0] C[bn * BLOCK_N + tn] = C_reduced[0]
...@@ -355,8 +343,7 @@ def main(do_bench: bool = True): ...@@ -355,8 +343,7 @@ def main(do_bench: bool = True):
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench( check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
print("Test passed!") print("Test passed!")
......
...@@ -5,21 +5,8 @@ import tilelang ...@@ -5,21 +5,8 @@ import tilelang
import tilelang.language as T import tilelang.language as T
@tilelang.jit( @tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
out_idx=[2], pass_configs={ def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def grouped_gemm_fwd(batch_sum,
batch_count,
K,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
""" """
args: args:
a (torch.Tensor): Input tensor of shape (M, K). a (torch.Tensor): Input tensor of shape (M, K).
...@@ -29,17 +16,14 @@ def grouped_gemm_fwd(batch_sum, ...@@ -29,17 +16,14 @@ def grouped_gemm_fwd(batch_sum,
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Tensor([batch_sum, K], dtype), # type: ignore A: T.Tensor([batch_sum, K], dtype), # type: ignore
B: T.Tensor([batch_count, K, N], dtype), # type: ignore B: T.Tensor([batch_count, K, N], dtype), # type: ignore
C: T.Tensor([batch_sum, N], dtype), # type: ignore C: T.Tensor([batch_sum, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
): ):
with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by):
with T.Kernel(
T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
threads=threads) as (bx, by):
A_shared = T.alloc_shared([block_M, block_K], dtype) A_shared = T.alloc_shared([block_M, block_K], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
...@@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum, ...@@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum,
m_start_padded = bx * block_M m_start_padded = bx * block_M
for i in range(batch_count): for i in range(batch_count):
in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])
cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
cur_batch_idx[0]] actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
actual_rows = T.max(
0,
T.min(block_M,
cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
T.clear(C_local) T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
T.copy( T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
B[cur_batch_idx[0], k * block_K:(k + 1) * block_K,
by * block_N:(by + 1) * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local) T.gemm(A_shared, B_shared, C_local)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum, ...@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
class _GroupedGEMM(torch.autograd.Function): class _GroupedGEMM(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, a, b, batch_sizes): def forward(ctx, a, b, batch_sizes):
block_M = 64 block_M = 64
...@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function): ...@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
for i in range(batch_count - 1): for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i]) batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i])
for i in range(batch_count - 1): for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M)
math.ceil((batch_sizes[i] + 1) / padding_M) *
padding_M)
batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32) batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32)
batch_padded_offsets = torch.tensor( batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32)
batch_padded_offsets_list, device=a.device, dtype=torch.int32)
kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads)
num_stages, threads)
o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets)
ctx.save_for_backward(a, b, batch_sizes, batch_offsets) ctx.save_for_backward(a, b, batch_sizes, batch_offsets)
...@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function): ...@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
return x return x
A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)]
kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads)
num_stages, threads)
dB = kernel(A, grad_output, batch_sizes, batch_offsets) dB = kernel(A, grad_output, batch_sizes, batch_offsets)
return None, dB, None return None, dB, None
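For reference, the backward above computes, per group g, dB[g] = A_g^T @ dO_g over that group's rows, which is what `transpose_A=True` in the `T.gemm` call expresses. A hedged PyTorch equivalent of the kernel's math:

import torch

def grouped_gemm_bwd_ref(A, grad_output, batch_sizes, batch_offsets):
    dB = torch.zeros(len(batch_sizes), A.shape[1], grad_output.shape[1],
                     device=A.device, dtype=A.dtype)
    for g, (size, off) in enumerate(zip(batch_sizes.tolist(), batch_offsets.tolist())):
        dB[g] = A[off:off + size].t() @ grad_output[off:off + size]   # A_g^T @ dO_g
    return dB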
...@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): ...@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for i in range(batch_count - 1): for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
for i in range(batch_count - 1): for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M)
math.ceil((batch_sizes_list[i] + 1) / padding_M) *
padding_M)
A = torch.randn(batch_sum, K, device=device, dtype=dtype) A = torch.randn(batch_sum, K, device=device, dtype=dtype)
B = torch.randn(batch_count, K, M, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
C = torch.empty(batch_sum, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype)
...@@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): ...@@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets
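Worked offsets for the defaults above (`--batch_sizes "64, 128"` with `padding_M = block_M = 64`): `batch_offsets` are the running row starts of each group in A, and `batch_padded_offsets` round each group up to a block boundary (this file pads with `batch_sizes[i] + 1`):

import math

batch_sizes_list, padding_M = [64, 128], 64
batch_offsets, batch_padded_offsets = [0], [0]
for i in range(len(batch_sizes_list) - 1):
    batch_offsets.append(batch_offsets[-1] + batch_sizes_list[i])               # -> [0, 64]
    batch_padded_offsets.append(
        batch_padded_offsets[-1]
        + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M)         # ceil(65/64)*64 -> [0, 128]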
@tilelang.jit( @tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
out_idx=[2], pass_configs={ def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def grouped_gemm_bwd(batch_sum,
batch_count,
M,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
""" """
args: args:
a (torch.Tensor): Input tensor of shape (M, K). a (torch.Tensor): Input tensor of shape (M, K).
...@@ -211,16 +168,13 @@ def grouped_gemm_bwd(batch_sum, ...@@ -211,16 +168,13 @@ def grouped_gemm_bwd(batch_sum,
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Tensor([batch_sum, M], dtype), # type: ignore A: T.Tensor([batch_sum, M], dtype), # type: ignore
B: T.Tensor([batch_sum, N], dtype), # type: ignore B: T.Tensor([batch_sum, N], dtype), # type: ignore
C: T.Tensor([batch_count, M, N], dtype), # type: ignore C: T.Tensor([batch_count, M, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
): ):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz):
with T.Kernel(
T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared([block_K, block_M], dtype) A_shared = T.alloc_shared([block_K, block_M], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
...@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum, ...@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
T.clear(C_local) T.clear(C_local)
for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages):
for i, j in T.Parallel(block_K, block_M): for i, j in T.Parallel(block_K, block_M):
A_shared[i, j] = T.if_then_else( A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0)
i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i,
bx * block_M + j], 0)
for i, j in T.Parallel(block_K, block_N): for i, j in T.Parallel(block_K, block_N):
B_shared[i, j] = T.if_then_else( B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0)
i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i,
by * block_N + j], 0)
T.gemm(A_shared, B_shared, C_local, transpose_A=True) T.gemm(A_shared, B_shared, C_local, transpose_A=True)
T.copy(C_local, C[bz, bx * block_M, by * block_N]) T.copy(C_local, C[bz, bx * block_M, by * block_N])
...@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum, ...@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
return kernel return kernel
def run_tilelang_grouped_gemm(batch_sizes_list, def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False):
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages=2,
threads=128,
profile=False):
padding_M = block_M padding_M = block_M
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 dtype = torch.float16
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype)
batch_sizes_list, K, M, False, padding_M, device, dtype)
A.requires_grad_(False) A.requires_grad_(False)
B.requires_grad_(True) B.requires_grad_(True)
...@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, ...@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
O.backward(dO, retain_graph=True) O.backward(dO, retain_graph=True)
dB, B.grad = B.grad.clone(), None dB, B.grad = B.grad.clone(), None
if ( if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2):
torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \
torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2)
):
print("✅ Tilelang and Torch match") print("✅ Tilelang and Torch match")
else: else:
print("❌ Tilelang and Torch mismatch") print("❌ Tilelang and Torch mismatch")
...@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list, ...@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
'--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') parser.add_argument("--K", type=int, default=8192, help="reduce dim")
parser.add_argument('--K', type=int, default=8192, help='reduce dim') parser.add_argument("--M", type=int, default=8192, help="output dim")
parser.add_argument('--M', type=int, default=8192, help='output dim') parser.add_argument("--trans_b", action="store_true", help="transpose B")
parser.add_argument('--trans_b', action="store_true", help="transpose B") parser.add_argument("--profile", action="store_true", help="profile")
parser.add_argument('--profile', action="store_true", help="profile")
args = parser.parse_args() args = parser.parse_args()
batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
...@@ -301,14 +236,4 @@ if __name__ == "__main__": ...@@ -301,14 +236,4 @@ if __name__ == "__main__":
num_stages = 2 num_stages = 2
threads = 256 threads = 256
run_tilelang_grouped_gemm( run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile)
batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages,
threads,
profile=args.profile)
...@@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): ...@@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
torch.Tensor: Resulting tensor after grouped matrix multiplication. torch.Tensor: Resulting tensor after grouped matrix multiplication.
""" """
assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a" assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a"
assert b.shape[0] == len( assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes"
batch_sizes), "The first dimension of b must match the length of batch_sizes"
# Initialize output tensor # Initialize output tensor
output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype)
...@@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): ...@@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def grouped_gemm(batch_sizes_list, def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
K,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
""" """
args: args:
a (torch.Tensor): Input tensor of shape (M, K). a (torch.Tensor): Input tensor of shape (M, K).
...@@ -59,14 +50,13 @@ def grouped_gemm(batch_sizes_list, ...@@ -59,14 +50,13 @@ def grouped_gemm(batch_sizes_list,
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Tensor([batch_sum, K], dtype), # type: ignore A: T.Tensor([batch_sum, K], dtype), # type: ignore
B: T.Tensor([batch_count, K, N], dtype), # type: ignore B: T.Tensor([batch_count, K, N], dtype), # type: ignore
C: T.Tensor([batch_sum, N], dtype), # type: ignore C: T.Tensor([batch_sum, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
): ):
with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared([block_M, block_K], dtype) A_shared = T.alloc_shared([block_M, block_K], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype)
...@@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list, ...@@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list,
m_start_padded = bx * block_M m_start_padded = bx * block_M
for i in range(batch_count): for i in range(batch_count):
in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i]) in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])
cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[ m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
cur_batch_idx[0]] actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
actual_rows = T.max(
0,
T.min(block_M,
cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
T.clear(C_local) T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared) T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
T.copy( T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
B[cur_batch_idx[0], k * block_K:(k + 1) * block_K,
by * block_N:(by + 1) * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local) T.gemm(A_shared, B_shared, C_local)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): ...@@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for i in range(batch_count - 1): for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
for i in range(batch_count - 1): for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
A = torch.randn(batch_sum, K, device=device, dtype=dtype) A = torch.randn(batch_sum, K, device=device, dtype=dtype)
B = torch.randn(batch_count, K, M, device=device, dtype=dtype) B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
C = torch.empty(batch_sum, M, device=device, dtype=dtype) C = torch.empty(batch_sum, M, device=device, dtype=dtype)
...@@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): ...@@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets
def run_tilelang_grouped_gemm(batch_sizes_list, def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False):
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages=2,
threads=128,
profile=False):
padding_M = block_M padding_M = block_M
batch_sum = sum(batch_sizes_list) batch_sum = sum(batch_sizes_list)
kernel = grouped_gemm( kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
# print(kernel.get_kernel_source()) # print(kernel.get_kernel_source())
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 dtype = torch.float16
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs( A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype)
batch_sizes_list, K, M, trans_b, padding_M, device, dtype)
out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets) out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets)
ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b) ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b)
# print(out) # print(out)
...@@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, ...@@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if profile: if profile:
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
latency = profiler.do_bench( latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets])
warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets])
print(f"Latency: {latency} ms") print(f"Latency: {latency} ms")
print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops") print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops")
...@@ -173,12 +144,11 @@ def test_grouped_gemm(): ...@@ -173,12 +144,11 @@ def test_grouped_gemm():
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
'--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes') parser.add_argument("--K", type=int, default=8192, help="reduce dim")
parser.add_argument('--K', type=int, default=8192, help='reduce dim') parser.add_argument("--M", type=int, default=8192, help="output dim")
parser.add_argument('--M', type=int, default=8192, help='output dim') parser.add_argument("--trans_b", action="store_true", help="transpose B")
parser.add_argument('--trans_b', action="store_true", help="transpose B") parser.add_argument("--profile", action="store_true", help="profile")
parser.add_argument('--profile', action="store_true", help="profile")
args = parser.parse_args() args = parser.parse_args()
batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
...@@ -190,14 +160,4 @@ if __name__ == "__main__": ...@@ -190,14 +160,4 @@ if __name__ == "__main__":
num_stages = 2 num_stages = 2
threads = 256 threads = 256
run_tilelang_grouped_gemm( run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile)
batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages,
threads,
profile=args.profile)
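Assuming this example lives in a file such as `example_grouped_gemm.py` (name illustrative), the CLI defined above is driven with the flags from the parser, e.g. `python example_grouped_gemm.py --batch_sizes "64,128" --K 8192 --M 8192 --profile`.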