Commit 3db18726 authored by Cunxiao Ni, committed by LeiWang1999

[Example] Update examples to use @tilelang.jit (#597)



* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored multiple examples to remove the explicit `tilelang.compile` step for kernel creation; the kernel factory functions are now invoked directly.
- Added `@tilelang.jit` decorators with the appropriate output indices (`out_idx`) so that calling a factory returns a compiled kernel, improving conciseness and maintainability.
- Simplified and unified kernel invocation across the examples so that kernels are defined and executed consistently (see the pattern sketch below).
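
For reference, the change follows the pattern below. The snippet is condensed from the elementwise-add example; the kernel body is abbreviated and may differ in detail from the file in the repository.

    import torch
    import tilelang
    import tilelang.language as T

    # Before: build the TIR program, then compile it explicitly.
    #   program = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
    #   kernel = tilelang.compile(program, out_idx=-1)

    # After: decorate the kernel factory. Calling it returns a compiled kernel,
    # and out_idx marks which parameter is treated as the output tensor.
    @tilelang.jit(out_idx=[-1])
    def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

        @T.prim_func
        def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype),
                     C: T.Tensor((M, N), out_dtype)):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                for local_y, local_x in T.Parallel(block_M, block_N):
                    y = by * block_M + local_y
                    x = bx * block_N + local_x
                    C[y, x] = A[y, x] + B[y, x]

        return elem_add

    config = {"block_M": 128, "block_N": 256, "threads": 128}
    kernel = elementwise_add(1024, 1024, **config, in_dtype="float32", out_dtype="float32")

    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
    out = kernel(a, b)  # the tensor marked by out_idx is allocated and returned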

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 18889821
@@ -10,6 +10,7 @@ def ref_program(x, y):
     return x + y


+@tilelang.jit(out_idx=[-1])
 def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

     @T.prim_func
@@ -68,8 +69,7 @@ def main():
     else:
         # Default config
         config = {"block_M": 128, "block_N": 256, "threads": 128}
-    kernel = tilelang.compile(
-        elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32"), out_idx=-1)
+    kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
...
@@ -65,6 +65,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
     dtype = "float16"
     accum_dtype = "float"

+    @tilelang.jit(out_idx=[3])
     def kernel_func(block_M, block_N, num_stages, threads):

         @T.macro
@@ -240,11 +241,10 @@ def main(batch: int = 1,
     total_flops *= 0.5

     if (not tune):
-        program = flashattn(
+        kernel = flashattn(
             batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)(
                 block_M=64, block_N=64, num_stages=2, threads=128)
         ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
-        kernel = tilelang.compile(program, out_idx=[3])
         profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
         print("All checks pass.")
...
@@ -32,6 +32,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
     dtype = "float16"
     accum_dtype = "float"

+    @tilelang.jit(out_idx=[3])
     def kernel_func(block_M, block_N, num_stages, threads):

         @T.macro
@@ -214,11 +215,10 @@ def main(
     total_flops *= 0.5

     if (not tune):
-        program = flashattn(
+        kernel = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)(
                block_M=128, block_N=128, num_stages=2, threads=256)
         ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
-        kernel = tilelang.compile(program, out_idx=[3])
         profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
         print("All checks pass.")
...
@@ -32,6 +32,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
     dtype = "float16"
     accum_dtype = "float"

+    @tilelang.jit(out_idx=[3])
     def kernel_func(block_M, block_N, num_stages, threads):

         @T.macro
@@ -197,11 +198,10 @@ def main(
     total_flops *= 0.5

     if (not tune):
-        program = flashattn(
+        kernel = flashattn(
            batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)(
                block_M=64, block_N=64, num_stages=1, threads=128)
         ref_program_processed = partial(ref_program, is_causal=is_causal)
-        kernel = tilelang.compile(program, out_idx=[3])
         profiler = kernel.get_profiler()
         profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...
@@ -32,6 +32,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
     dtype = "float16"
     accum_dtype = "float"

+    @tilelang.jit(out_idx=[3])
     def kernel_func(block_M, block_N, num_stages, threads):

         @T.macro
@@ -202,11 +203,10 @@ def main(
     total_flops *= 0.5

     if (not tune):
-        program = flashattn(
+        kernel = flashattn(
            batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)(
                block_M=128, block_N=128, num_stages=2, threads=256)
         ref_program_processed = partial(ref_program, is_causal=is_causal)
-        kernel = tilelang.compile(program, out_idx=[3])
         profiler = kernel.get_profiler()
         profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...
@@ -30,6 +30,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
     dtype = "float16"
     accum_dtype = "float"

+    @tilelang.jit(out_idx=[3])
     def kernel_func(block_M, block_N, num_stages, threads):

         @T.macro
@@ -191,11 +192,10 @@ def main(
     total_flops *= 0.5

     if (not tune):
-        program = flashattn(
+        kernel = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=tune)(
                block_M=128, block_N=128, num_stages=1, threads=128)
         ref_program_processed = partial(ref_program, is_causal=is_causal)
-        kernel = tilelang.compile(program, out_idx=[3])
         profiler = kernel.get_profiler()
         profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
         print("All checks pass.")
...
@@ -30,6 +30,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
     dtype = "float16"
     accum_dtype = "float"

+    @tilelang.jit(out_idx=[3])
     def kernel_func(block_M, block_N, num_stages, threads):

         @T.macro
@@ -196,11 +197,10 @@ def main(
     total_flops *= 0.5

     if (not tune):
-        program = flashattn(
+        kernel = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=tune)(
                block_M=128, block_N=128, num_stages=2, threads=256)
         ref_program_processed = partial(ref_program, is_causal=is_causal)
-        kernel = tilelang.compile(program, out_idx=[3])
         profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
         print("All checks pass.")
...
@@ -232,6 +232,7 @@ def flashattn(batch_size, UQ, UKV, heads, dim, is_causal):
     dtype = "float16"
     accum_dtype = "float"

+    @tilelang.jit(out_idx=[6])
     def kernel_func(block_M, block_N, num_stages, threads):

         @T.prim_func
@@ -400,8 +401,7 @@ def main(batch: int = 2, heads: int = 16, seq_len: int = 256, dim: int = 32):
     UK = k_unpad.shape[0]  # unpadded key length
     UKV = k_unpad.shape[0]  # unpadded query key length

-    program = flashattn(batch, UQ, UKV, heads, dim, causal)
-    kernel = tilelang.compile(program, [6])
+    kernel = flashattn(batch, UQ, UKV, heads, dim, causal)
     print(kernel.get_kernel_source())

     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
...
@@ -8,6 +8,7 @@ from functools import partial
 num_split = 4


+@tilelang.jit(out_idx=[5])
 def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape_q = [batch, seqlen_q, heads, dim]
@@ -303,9 +304,8 @@ def main():
     total_flops *= 0.5
     BLOCK_M = 128
     BLOCK_N = 64  # if D_HEAD <= 128 else 32
-    program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
+    kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
     ref_program_processed = partial(ref_program, causal=causal)
-    kernel = tilelang.compile(program, out_idx=[5])
     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
     profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
     print("All checks passed!")
...
@@ -6,6 +6,8 @@ import tilelang.language as T
 from einops import rearrange, einsum
 import argparse
 import itertools
+from functools import lru_cache
+from typing import Tuple, Dict


 torch.random.manual_seed(0)
@@ -28,6 +30,30 @@ def get_configs():
     return configs


+@lru_cache(maxsize=1)
+def get_heuristic_config() -> Tuple[Dict, int]:
+    # Get CUDA device properties
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device = torch.cuda.current_device()
+    sm_major, sm_minor = torch.cuda.get_device_capability(device)
+    sm_version = sm_major * 10 + sm_minor
+    print(f"CUDA device capability: {sm_version}")
+    if sm_version == 89:
+        cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128)
+    else:
+        cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128)
+    return cfg, sm_version
+
+
+def get_pass_configs():
+    _, sm_version = get_heuristic_config()
+    if sm_version == 80:
+        return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}
+    else:
+        return {}
+
+
 def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape_q = [batch, heads, dim]
@@ -38,6 +64,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
     accum_dtype = "float"
     kv_group_num = heads // groups

+    @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
     def kernel_func(block_N, block_H, num_split, num_stages, threads):
         part_shape = [batch, heads, num_split, dim]
         valid_block_H = min(block_H, kv_group_num)
@@ -457,39 +484,8 @@ def main(batch: int = 1,
     total_flops = qk_flops + pv_flops

     if (not tune):
-        def get_heuristic_config() -> dict:
-            # Get CUDA device properties
-            if not torch.cuda.is_available():
-                raise RuntimeError("CUDA is not available")
-            device = torch.cuda.current_device()
-            sm_major, sm_minor = torch.cuda.get_device_capability(device)
-            sm_version = sm_major * 10 + sm_minor
-            print(f"CUDA device capability: {sm_version}")
-
-            if sm_version == 89:
-                return {
-                    "block_N": 128,
-                    "block_H": 64,
-                    "num_split": 16,
-                    "num_stages": 0,
-                    "threads": 128
-                }, sm_version
-            else:
-                return {
-                    "block_N": 128,
-                    "block_H": 64,
-                    "num_split": 16,
-                    "num_stages": 2,
-                    "threads": 128
-                }, sm_version
-
         config, sm_version = get_heuristic_config()
-        program = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config)
-        if sm_version == 90:
-            kernel = tilelang.compile(
-                program, out_idx=[6], pass_configs={"tl.disable_tma_lower": True})
-        else:
-            kernel = tilelang.compile(program, out_idx=[6])
+        kernel = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config)
         profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)

         q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
...
@@ -8,6 +8,7 @@ from functools import partial
 num_split = 4


+@tilelang.jit(out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True})
 def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape_q = [batch, seqlen_q, heads, dim]
@@ -302,10 +303,8 @@ def main():
     total_flops *= 0.5
     BLOCK_M = 128
     BLOCK_N = 64  # if D_HEAD <= 128 else 32
-    program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
+    kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
     ref_fn = partial(ref_program, causal=causal)
-    kernel = tilelang.compile(
-        program, out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True})
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
     profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01)
@@ -320,4 +319,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
@@ -2,6 +2,7 @@ import tilelang
 import tilelang.language as T


+@tilelang.jit(out_idx=[-1])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

     @T.prim_func
@@ -27,11 +28,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
 def main():
-    func = matmul(1024, 1024, 1024, 128, 128, 32)
-    print(func)
-
-    kernel = tilelang.compile(func, out_idx=-1)
+    kernel = matmul(1024, 1024, 1024, 128, 128, 32)

     import torch
...
@@ -164,6 +164,7 @@ def get_heuristic_config() -> dict:
     }


+@tl.jit(out_idx=[-1])
 def matmul(M,
            N,
            K,
@@ -217,7 +218,7 @@ def main(m: int = 4096,
         kernel = result.kernel
     else:
         config = get_heuristic_config()
-        kernel = tl.compile(matmul(M, N, K, **config), out_idx=-1)
+        kernel = matmul(M, N, K, **config)

     # benchmark
     profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
...
@@ -23,6 +23,7 @@ def make_swizzle_layout(shared_buf):
     return T.Layout(shape, transform_func)


+@tilelang.jit(out_idx=[2])
 @simplify_prim_func
 def tl_matmul(
     M,
@@ -164,8 +165,7 @@ def ref_program(A, B):
 def main():
     M, N, K = 16384, 16384, 16384
     in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
-    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    kernel = tilelang.compile(matmul, out_idx=[2])
+    kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
     src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
...
@@ -4,6 +4,7 @@ from tilelang.carver.arch import driver
 import argparse


+@tilelang.jit(out_idx=[-1])
 def matmul_non_persistent(M,
                           N,
                           K,
@@ -41,6 +42,7 @@ def matmul_non_persistent(M,
     return main


+@tilelang.jit(out_idx=[-1])
 def matmul_persistent(M,
                       N,
                       K,
@@ -131,8 +133,7 @@ def main():
     threads = 256
     num_stages = 3

-    persistent_program = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
-    persistent_kernel = tilelang.compile(persistent_program, out_idx=-1)
+    persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
     persistent_profiler = persistent_kernel.get_profiler(
         tensor_supply_type=tilelang.TensorSupplyType.Randn)
     persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
@@ -141,9 +142,8 @@ def main():
     print(f"Persistent GEMM Latency: {persistent_latency} ms")
     print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")

-    non_persistent_program = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads,
-                                                   num_stages)
-    non_persistent_kernel = tilelang.compile(non_persistent_program, out_idx=-1)
+    non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads,
+                                                  num_stages)
     non_persistent_profiler = non_persistent_kernel.get_profiler(
         tensor_supply_type=tilelang.TensorSupplyType.Randn)
     non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
...
@@ -2,6 +2,7 @@ import tilelang
 import tilelang.language as T


+@tilelang.jit(out_idx=[-1])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

     @T.prim_func
@@ -40,11 +41,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
 def main():
-    func = matmul(1024, 1024, 1024, 128, 128, 32)
-    print(func)
-
-    kernel = tilelang.compile(func, out_idx=-1)
+    kernel = matmul(1024, 1024, 1024, 128, 128, 32)

     import torch
...
@@ -11,6 +11,7 @@ def calc_diff(x, y):
     return 1 - sim


+@tilelang.jit(out_idx=[-1])
 def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):

     @T.prim_func
@@ -38,9 +39,8 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
 def test_gemm_fp8(M, N, K, dtype):
     torch_dtype = map_torch_type(dtype)

-    func = matmul(M, N, K, 128, 128, 64, dtype)
-    kernel = tilelang.compile(func, out_idx=-1)
+    kernel = matmul(M, N, K, 128, 128, 64, dtype)

     a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
     b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
@@ -62,4 +62,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
@@ -4,6 +4,7 @@ import tilelang.language as T
 from tilelang.utils.tensor import map_torch_type


+@tilelang.jit(out_idx=[-1])
 def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
     # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128.
     # if block_K < 128, promote after 128/block_K iters.
@@ -56,9 +57,7 @@ def calc_diff(x, y):
 def test_gemm_fp8(M, N, K, dtype):
     torch_dtype = map_torch_type(dtype)

-    func = matmul(M, N, K, 128, 128, 64, dtype)
-    kernel = tilelang.compile(func, out_idx=-1)
+    kernel = matmul(M, N, K, 128, 128, 64, dtype)

     a = torch.rand(M, K, dtype=torch.float16, device='cuda')
     a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
@@ -80,4 +79,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
@@ -28,6 +28,7 @@ def make_swizzle_layout(shared_buf):
     return T.Layout(shape, transform_func)


+@tilelang.jit(out_idx=[2])
 @simplify_prim_func
 def tl_matmul(
     M,
@@ -176,8 +177,7 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
-    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    kernel = tilelang.compile(matmul, out_idx=[2])
+    kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
     src_code = kernel.get_kernel_source()
     print(src_code)
     # src_code is the generated cuda source
@@ -221,4 +221,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
@@ -2,6 +2,7 @@ import tilelang
 import tilelang.language as T


+@tilelang.jit
 def matmul(M,
            N,
            K,
@@ -59,9 +60,7 @@ def main():
     block_K = 32
     split_k = 4

-    program = matmul(M, N, K, block_M, block_N, block_K, split_k)
-    kernel = tilelang.compile(program)
+    kernel = matmul(M, N, K, block_M, block_N, block_K, split_k)

     import torch
...