Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
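
This commit replaces the yapf formatter with ruff format across the example kernels. The rewrites below follow ruff format's rules: string literals switch from single to double quotes, calls and signatures that fit on one line are collapsed, and a trailing ("magic") comma after the last element keeps a call or literal expanded with one item per line. The unusually long single-line rewrites suggest a line-length setting above ruff's 88-character default, though the configuration file itself is not part of this diff. Below is a minimal, self-contained sketch of these rules; the names and values are placeholders borrowed from the examples in this diff, not an exact excerpt of any one file.

import argparse

# ruff format prefers double quotes and collapses a call onto one line
# when it fits within the configured line length:
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")

# A trailing ("magic") comma after the last entry keeps the literal expanded,
# one item per line, which is why the config dictionaries below stay multi-line:
config = {
    "block_M": 128,
    "block_N": 256,
    "block_K": 32,
    "num_stages": 2,
    "thread_num": 128,
    "enable_rasteration": True,  # keep param name for backward-compat
}

if __name__ == "__main__":
    args = parser.parse_args()
    print(config, args.use_autotune)
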
......@@ -4,7 +4,6 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
......
......@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
num_stages,
thread_num,
enable_rasterization,
))
)
)
configs = [
{
......@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
}
for c in _configs
]
return configs
def get_best_config(M, N, K, with_roller=False):
def kernel(
block_M=None,
block_N=None,
......@@ -124,8 +125,7 @@ def get_best_config(M, N, K, with_roller=False):
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
return main
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
autotuner = (
AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
.set_compile_args(
out_idx=[-1],
target="auto",
).set_profile_args(
)
.set_profile_args(
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
)
)
return autotuner.run(warmup=3, rep=20)
......@@ -167,47 +170,15 @@ def get_heuristic_config() -> dict:
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version in {80}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
elif sm_version in {90}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
else:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}
@tl.jit(out_idx=[-1])
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm_autotune(
A: T.Tensor((M, K), dtype),
......@@ -236,11 +207,7 @@ def matmul(M,
return gemm_autotune
def main(M: int = 4096,
N: int = 4096,
K: int = 4096,
use_autotune: bool = False,
with_roller: bool = False):
def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False):
use_autotune = True
if use_autotune:
result = get_best_config(M, N, K, with_roller)
......@@ -266,15 +233,7 @@ if __name__ == "__main__":
parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
parser.add_argument(
"--use_autotune",
action="store_true",
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=False,
help="Whether to enable BitBLAS roller for search space")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
......@@ -4,7 +4,8 @@ import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
......@@ -104,7 +105,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -112,10 +112,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -123,7 +125,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -133,7 +134,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
......
......@@ -5,17 +5,7 @@ import argparse
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M,
N,
K,
block_M,
block_N,
block_K,
threads,
num_stages,
dtype="float16",
accum_dtype="float"):
def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......@@ -43,18 +33,9 @@ def matmul_non_persistent(M,
@tilelang.jit(out_idx=[-1])
def matmul_persistent(M,
N,
K,
block_M,
block_N,
block_K,
threads,
num_stages,
dtype="float16",
accum_dtype="float",
use_persistent_primitive=True):
def matmul_persistent(
M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True
):
sm_num = driver.get_num_sms()
m_blocks = T.ceildiv(M, block_M)
n_blocks = T.ceildiv(N, block_N)
......@@ -100,8 +81,7 @@ def matmul_persistent(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
for bx, by in T.Persistent(
[T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
......@@ -128,18 +108,15 @@ def main(M=4096, N=4096, K=4096):
num_stages = 3
persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
persistent_profiler = persistent_kernel.get_profiler(
tensor_supply_type=tilelang.TensorSupplyType.Randn)
persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Persistent GEMM: All check passed.")
persistent_latency = persistent_profiler.do_bench(warmup=500)
print(f"Persistent GEMM Latency: {persistent_latency} ms")
print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")
non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads,
num_stages)
non_persistent_profiler = non_persistent_kernel.get_profiler(
tensor_supply_type=tilelang.TensorSupplyType.Randn)
non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Non-Persistent GEMM: All check passed.")
non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
......@@ -151,9 +128,9 @@ def main(M=4096, N=4096, K=4096):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=8192, help='M dimension')
parser.add_argument('--N', type=int, default=8192, help='N dimension')
parser.add_argument('--K', type=int, default=8192, help='K dimension')
parser.add_argument("--M", type=int, default=8192, help="M dimension")
parser.add_argument("--N", type=int, default=8192, help="N dimension")
parser.add_argument("--K", type=int, default=8192, help="K dimension")
args = parser.parse_args()
M, N, K = args.M, args.N, args.K
main(M, N, K)
......@@ -4,7 +4,6 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm_schedule(
A: T.Tensor((M, K), dtype),
......
......@@ -17,10 +17,8 @@ def supply_prog(args):
a_param, b_param = args
M, K = a_param.shape
N, _ = b_param.shape
a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
return [a, b]
......@@ -35,10 +33,9 @@ def get_configs():
valid_configs = []
for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
num_stages, num_threads, k_packs,
gemm_types):
valid_configs.append({
for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"block_K": k,
......@@ -46,16 +43,14 @@ def get_configs():
"num_threads": t,
"k_pack": kp,
"gemm_type": gemm_type,
})
}
)
return valid_configs
@tilelang.autotune(
configs=get_configs(),
cache_input_tensors=True,
ref_prog=ref_program,
manual_check_prog=manual_check_prog,
supply_prog=supply_prog)
configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = "float8_e4m3fnuz"
......@@ -67,8 +62,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -77,13 +71,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_local)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_local,
B_shared,
C_local,
transpose_B=True,
k_pack=k_pack,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
......@@ -93,8 +81,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -103,13 +90,7 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
k_pack=k_pack,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
......@@ -123,10 +104,8 @@ def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pa
def test_gemm_fp8(M, N, K):
kernel = fp8_matmul(M, N, K)
a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
......
......@@ -13,7 +13,6 @@ def calc_diff(x, y):
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
@T.prim_func
def gemm_fp8(
A: T.Tensor((M, K), dtype),
......@@ -41,8 +40,8 @@ def test_gemm_fp8(M, N, K, dtype):
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
c = kernel(a, b)
......@@ -57,8 +56,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2')
test_gemm_fp8(1024, 1024, 1024, "float8_e4m3")
test_gemm_fp8(1024, 1024, 1024, "float8_e5m2")
if __name__ == "__main__":
......
......@@ -59,14 +59,14 @@ def test_gemm_fp8(M, N, K, dtype):
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.rand(M, K, dtype=torch.float16, device='cuda')
a = torch.rand(M, K, dtype=torch.float16, device="cuda")
a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
b = torch.rand(N, K, dtype=torch.float16, device='cuda')
b = torch.rand(N, K, dtype=torch.float16, device="cuda")
b = (100 * (2 * b - 1)).to(dtype=torch_dtype)
c = kernel(a, b)
ref_c = (a.float() @ b.float().T)
ref_c = a.float() @ b.float().T
diff = calc_diff(c, ref_c)
print(f"diff: {diff}")
......@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2')
test_gemm_fp8(1024, 1024, 8192, "float8_e4m3")
test_gemm_fp8(1024, 1024, 8192, "float8_e5m2")
if __name__ == "__main__":
......
......@@ -5,7 +5,8 @@ from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
......@@ -115,7 +116,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -123,10 +123,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -134,7 +136,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -144,7 +145,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......
......@@ -121,6 +121,4 @@ for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
print(
f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS"
)
print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
......@@ -5,7 +5,6 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......@@ -62,7 +61,8 @@ jit_kernel = tilelang.compile(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
......
......@@ -40,15 +40,7 @@ def matmul(
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
......@@ -66,8 +58,7 @@ in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 2
threads = 256
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
func,
out_idx=[2],
......@@ -75,7 +66,8 @@ jit_kernel = tilelang.compile(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
print(jit_kernel.get_kernel_source())
......@@ -88,4 +80,4 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS")
print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
......@@ -17,77 +17,76 @@ torch.manual_seed(42)
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
"float": {
"block_M": 128,
"block_N": 64,
"block_K": 64,
"num_stages": 1,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
"block_M": 256,
"block_N": 128,
"block_K": 64,
"num_stages": 2,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
"float": {
"block_M": 128,
"block_N": 64,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
"block_M": 128,
"block_N": 64,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
thread_num, policy, enable_rasterization, use_cutlass_layout):
def matmul_sp_fp16_custom_compress(
M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout
):
e_factor, e_dtype = (16, "int16")
@T.prim_func
def gemm_sp_fp16_custom_compress(
A_sparse: T.Tensor((M, K // 2), 'float16'),
A_sparse: T.Tensor((M, K // 2), "float16"),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
B: T.Tensor((K, N), "float16"),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
B_shared = T.alloc_shared((block_K, block_N), "float16")
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
if use_cutlass_layout:
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
}
)
T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
......@@ -108,8 +107,7 @@ def torch_compress(dense):
A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout.
"""
if dense.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")
raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor")
m, k = dense.shape
......@@ -131,9 +129,7 @@ def torch_compress(dense):
if m % 32 != 0:
raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32")
if k % (4 * quadbits_per_meta_elem) != 0:
raise RuntimeError(
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
)
raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}")
if dense.dtype != torch.float:
ksparse = 4
......@@ -194,19 +190,13 @@ def torch_compress(dense):
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1,
idxs0.unsqueeze(-1) // 2).view(
m, k // 2) # type: ignore[possibly-undefined]
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
if quadbits_per_meta_elem == 4:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12))
meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12)
elif quadbits_per_meta_elem == 8:
meta = (
meta_n[:, :, 0]
......@@ -216,7 +206,8 @@ def torch_compress(dense):
| (meta_n[:, :, 4] << 16)
| (meta_n[:, :, 5] << 20)
| (meta_n[:, :, 6] << 24)
| (meta_n[:, :, 7] << 28))
| (meta_n[:, :, 7] << 28)
)
return (sparse, meta)
......@@ -234,9 +225,11 @@ def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
@tilelang.jit(
out_idx=[1, 2], pass_configs={
out_idx=[1, 2],
pass_configs={
tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
},
)
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
e_factor, e_dtype = ARCH_INFO["8.0"]
e_K = K // e_factor
......@@ -258,14 +251,12 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
if use_cutlass_layout:
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
}
)
T.clear(A_sp_shared)
T.clear(E_shared)
# TODO: alloc_var seems buggy here
......@@ -295,8 +286,7 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
non_zero_elt_log_idx[1] = 3
for i in T.serial(elem):
val = non_zero_elt_log_idx[i]
E_shared[tm, a_k // e_factor] |= T.shift_left(
val, 4 * (g_i % (e_factor // group)) + 2 * i)
E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
......@@ -304,41 +294,27 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--use_cutlass_layout", action='store_true', help="Use cutlass layout for E tensor")
parser.add_argument(
"--use_torch_compressor", action='store_true', help="Use torch sparse for reference")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16_custom_compress(
args.m,
args.n,
args.k,
args.accum_dtype,
**DEFAULT_CONFIG[args.cfg][args.accum_dtype],
use_cutlass_layout=args.use_cutlass_layout)
args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout
)
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
if args.use_torch_compressor:
assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
a_sparse, e = torch_compress(a)
else:
a_sparse, e = compress_kernel(
args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(
a)
a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a)
c = kernel(a_sparse, e, b)
......@@ -346,9 +322,7 @@ def main():
assert not c.isnan().any(), "Reference result contains NaNs, please report an issue"
torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
print(
f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}"
)
print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}")
latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)
......@@ -356,8 +330,8 @@ def main():
total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")
if __name__ == "__main__":
......
......@@ -16,80 +16,77 @@ arch = nvcc.get_target_compute_version()
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
"float": {
"block_M": 128,
"block_N": 64,
"block_K": 64,
"num_stages": 1,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
"block_M": 256,
"block_N": 128,
"block_K": 64,
"num_stages": 2,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
"float": {
"block_M": 128,
"block_N": 64,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
"block_M": 128,
"block_N": 64,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
enable_rasterization):
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization):
e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func
def gemm_sp_fp16(
A_sparse: T.Tensor((M, K // 2), 'float16'),
A_sparse: T.Tensor((M, K // 2), "float16"),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
B: T.Tensor((K, N), "float16"),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
B_shared = T.alloc_shared((block_K, block_N), "float16")
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", block_k=block_K, arch=arch),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
}
)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
......@@ -107,25 +104,15 @@ def main():
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
**DEFAULT_CONFIG[args.cfg][args.accum_dtype])
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half)
b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half)
a_sparse, e = compress(
a,
transposed=False,
block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch)
c = kernel(a_sparse, e, b)
ref_c = a @ b
......@@ -140,8 +127,8 @@ def main():
total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s")
if __name__ == "__main__":
......
......@@ -3,17 +3,7 @@ import tilelang.language as T
@tilelang.jit
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
splitK = K // split_k
@T.prim_func
......@@ -22,8 +12,7 @@ def matmul(M,
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
......
......@@ -3,17 +3,7 @@ import tilelang.language as T
@tilelang.jit
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
splitK = K // split_k
@T.prim_func
......@@ -22,8 +12,7 @@ def matmul(M,
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
......
......@@ -39,7 +39,7 @@ total_tiles = num_block_m * num_block_n
# Two-tile SK + DP
streamk_tiles = total_tiles % streamk_programs
if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1)
if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1)
streamk_tiles += streamk_programs
blocking_tiles = total_tiles - streamk_tiles
......@@ -135,7 +135,6 @@ def tl_matmul_streamk(
C: T.Tensor,
C_local: T.LocalBuffer,
):
for p in T.serial(sm_patition_factor):
tile_id = pid + streamk_tiles + p * total_sm
pid_m = tile_id // T.ceildiv(N, block_N)
......@@ -155,7 +154,6 @@ def tl_matmul_streamk(
C: T.Tensor((M, N), dtypeC),
):
with T.Kernel(streamk_programs, threads=threads) as pid:
A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
......
......@@ -20,7 +20,6 @@ def naive_gemv(
dtype: str = "float16",
accum_dtype: str = "float",
):
@T.prim_func
def main(
A: T.Tensor((K,), dtype),
......@@ -38,8 +37,7 @@ def naive_gemv(
A_shared[tk] = A[bk * BLOCK_K + tk]
B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
for tk in T.serial(BLOCK_K):
C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn,
tk].astype(accum_dtype)
C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype)
C[bn * BLOCK_N + tn] = C_reg[0]
return main
......@@ -54,7 +52,6 @@ def naive_splitk_gemv(
dtype: str = "float16",
accum_dtype: str = "float",
):
@T.prim_func
def main(
A: T.Tensor((K,), dtype),
......@@ -209,7 +206,8 @@ def splitk_gemv_vectorized_tvm(
C_reduced[0],
tk,
dtype="handle",
))
)
)
C[bn * BLOCK_N + tn] = C_reduced[0]
......@@ -218,10 +216,8 @@ def splitk_gemv_vectorized_tvm(
def get_block_template_configs():
iter_params = dict(
block_M=[2, 4, 8, 32, 64, 128],
block_N=[2, 4, 8, 32, 64, 128],
num_stages=[0, 1, 2, 3, 4],
threads=[32, 64, 128, 256])
block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256]
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
......@@ -237,18 +233,9 @@ def get_block_template_configs():
},
out_idx=[2],
)
def gemv_alloc_reducer(M,
N,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
accum_dtype: str = "float"):
def gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"):
@T.prim_func
def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M,
dtype)): # type: ignore
def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore
with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
T.clear(o_reducer)
......@@ -327,7 +314,8 @@ def get_autotuned_kernel(
C_reduced[0],
tk,
dtype="handle",
))
)
)
C[bn * BLOCK_N + tn] = C_reduced[0]
......@@ -355,8 +343,7 @@ def main(do_bench: bool = True):
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench)
check_correctness_and_bench(
gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench)
print("Test passed!")
......
......@@ -5,21 +5,8 @@ import tilelang
import tilelang.language as T
@tilelang.jit(
out_idx=[2], pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def grouped_gemm_fwd(batch_sum,
batch_count,
K,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
......@@ -36,10 +23,7 @@ def grouped_gemm_fwd(batch_sum,
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
):
with T.Kernel(
T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared([block_M, block_K], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
......@@ -49,23 +33,17 @@ def grouped_gemm_fwd(batch_sum,
m_start_padded = bx * block_M
for i in range(batch_count):
in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i])
in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])
cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[
cur_batch_idx[0]]
actual_rows = T.max(
0,
T.min(block_M,
cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared)
T.copy(
B[cur_batch_idx[0], k * block_K:(k + 1) * block_K,
by * block_N:(by + 1) * block_N], B_shared)
T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
for i, j in T.Parallel(block_M, block_N):
......@@ -76,7 +54,6 @@ def grouped_gemm_fwd(batch_sum,
class _GroupedGEMM(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b, batch_sizes):
block_M = 64
......@@ -99,15 +76,11 @@ class _GroupedGEMM(torch.autograd.Function):
for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i])
for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
math.ceil((batch_sizes[i] + 1) / padding_M) *
padding_M)
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M)
batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32)
batch_padded_offsets = torch.tensor(
batch_padded_offsets_list, device=a.device, dtype=torch.int32)
batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32)
kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K,
num_stages, threads)
kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads)
o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets)
ctx.save_for_backward(a, b, batch_sizes, batch_offsets)
......@@ -135,8 +108,7 @@ class _GroupedGEMM(torch.autograd.Function):
return x
A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)]
kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K,
num_stages, threads)
kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads)
dB = kernel(A, grad_output, batch_sizes, batch_offsets)
return None, dB, None
......@@ -172,9 +144,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
math.ceil((batch_sizes_list[i] + 1) / padding_M) *
padding_M)
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M)
A = torch.randn(batch_sum, K, device=device, dtype=dtype)
B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
C = torch.empty(batch_sum, M, device=device, dtype=dtype)
......@@ -187,21 +157,8 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets
@tilelang.jit(
out_idx=[2], pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
})
def grouped_gemm_bwd(batch_sum,
batch_count,
M,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
......@@ -217,10 +174,7 @@ def grouped_gemm_bwd(batch_sum,
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
):
with T.Kernel(
T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count,
threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared([block_K, block_M], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
......@@ -228,13 +182,9 @@ def grouped_gemm_bwd(batch_sum,
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages):
for i, j in T.Parallel(block_K, block_M):
A_shared[i, j] = T.if_then_else(
i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i,
bx * block_M + j], 0)
A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0)
for i, j in T.Parallel(block_K, block_N):
B_shared[i, j] = T.if_then_else(
i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i,
by * block_N + j], 0)
B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0)
T.gemm(A_shared, B_shared, C_local, transpose_A=True)
T.copy(C_local, C[bz, bx * block_M, by * block_N])
......@@ -242,23 +192,12 @@ def grouped_gemm_bwd(batch_sum,
return kernel
def run_tilelang_grouped_gemm(batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages=2,
threads=128,
profile=False):
def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False):
padding_M = block_M
device = torch.device("cuda")
dtype = torch.float16
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
batch_sizes_list, K, M, False, padding_M, device, dtype)
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype)
A.requires_grad_(False)
B.requires_grad_(True)
......@@ -273,10 +212,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
O.backward(dO, retain_graph=True)
dB, B.grad = B.grad.clone(), None
if (
torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and \
torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2)
):
if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2):
print("✅ Tilelang and Torch match")
else:
print("❌ Tilelang and Torch mismatch")
......@@ -284,12 +220,11 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes')
parser.add_argument('--K', type=int, default=8192, help='reduce dim')
parser.add_argument('--M', type=int, default=8192, help='output dim')
parser.add_argument('--trans_b', action="store_true", help="transpose B")
parser.add_argument('--profile', action="store_true", help="profile")
parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
parser.add_argument("--K", type=int, default=8192, help="reduce dim")
parser.add_argument("--M", type=int, default=8192, help="output dim")
parser.add_argument("--trans_b", action="store_true", help="transpose B")
parser.add_argument("--profile", action="store_true", help="profile")
args = parser.parse_args()
batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
......@@ -301,14 +236,4 @@ if __name__ == "__main__":
num_stages = 2
threads = 256
run_tilelang_grouped_gemm(
batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages,
threads,
profile=args.profile)
run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile)
......@@ -18,8 +18,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
torch.Tensor: Resulting tensor after grouped matrix multiplication.
"""
assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a"
assert b.shape[0] == len(
batch_sizes), "The first dimension of b must match the length of batch_sizes"
assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes"
# Initialize output tensor
output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype)
......@@ -38,15 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@tilelang.jit(out_idx=[2])
def grouped_gemm(batch_sizes_list,
K,
N,
block_M,
block_N,
block_K,
num_stages=2,
threads=128,
dtype="float16"):
def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
......@@ -66,7 +57,6 @@ def grouped_gemm(batch_sizes_list,
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
):
with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared([block_M, block_K], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype)
......@@ -77,23 +67,17 @@ def grouped_gemm(batch_sizes_list,
m_start_padded = bx * block_M
for i in range(batch_count):
in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i])
in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])
cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[
cur_batch_idx[0]]
actual_rows = T.max(
0,
T.min(block_M,
cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared)
T.copy(
B[cur_batch_idx[0], k * block_K:(k + 1) * block_K,
by * block_N:(by + 1) * block_N], B_shared)
T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
for i, j in T.Parallel(block_M, block_N):
......@@ -111,8 +95,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
for i in range(batch_count - 1):
batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
for i in range(batch_count - 1):
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
A = torch.randn(batch_sum, K, device=device, dtype=dtype)
B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
C = torch.empty(batch_sum, M, device=device, dtype=dtype)
......@@ -125,27 +108,16 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets
def run_tilelang_grouped_gemm(batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages=2,
threads=128,
profile=False):
def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False):
padding_M = block_M
batch_sum = sum(batch_sizes_list)
kernel = grouped_gemm(
tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
# print(kernel.get_kernel_source())
device = torch.device("cuda")
dtype = torch.float16
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
batch_sizes_list, K, M, trans_b, padding_M, device, dtype)
A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype)
out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets)
ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b)
# print(out)
......@@ -157,8 +129,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
if profile:
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
latency = profiler.do_bench(
warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets])
latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets])
print(f"Latency: {latency} ms")
print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops")
......@@ -173,12 +144,11 @@ def test_grouped_gemm():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes')
parser.add_argument('--K', type=int, default=8192, help='reduce dim')
parser.add_argument('--M', type=int, default=8192, help='output dim')
parser.add_argument('--trans_b', action="store_true", help="transpose B")
parser.add_argument('--profile', action="store_true", help="profile")
parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
parser.add_argument("--K", type=int, default=8192, help="reduce dim")
parser.add_argument("--M", type=int, default=8192, help="output dim")
parser.add_argument("--trans_b", action="store_true", help="transpose B")
parser.add_argument("--profile", action="store_true", help="profile")
args = parser.parse_args()
batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
......@@ -190,14 +160,4 @@ if __name__ == "__main__":
num_stages = 2
threads = 256
run_tilelang_grouped_gemm(
batch_sizes_list,
K,
M,
block_M,
block_N,
block_K,
trans_b,
num_stages,
threads,
profile=args.profile)
run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile)