Commit 3db18726 authored by Cunxiao Ni and committed by LeiWang1999

[Example] Update examples to use @tilelang.jit (#597)



* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored multiple examples to drop the explicit `tilelang.compile` step and invoke the decorated kernel functions directly (a minimal sketch of the pattern follows this list).
- Added `@tilelang.jit` decorators with the appropriate `out_idx` output indices to improve performance and maintainability.
- Simplified kernel invocation across the examples so that kernels are defined and executed consistently.
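
For reference, the migration looks roughly like the sketch below. It is modeled on tilelang's quickstart GEMM rather than on any specific example touched by this commit; the matmul body and the `T.Tensor` annotations are illustrative and may vary slightly across tilelang versions.

```python
# Minimal sketch of the @tilelang.jit migration pattern (assumed quickstart-style GEMM).
import tilelang
import tilelang.language as T

# Before: build the program, then compile it explicitly.
#   program = matmul(1024, 1024, 1024)
#   kernel = tilelang.compile(program, out_idx=[-1])
#
# After: decorate the factory; calling it returns a compiled kernel directly.

@tilelang.jit(out_idx=[-1])  # the tensor at index -1 (C) is allocated and returned
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float")
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

kernel = matmul(1024, 1024, 1024)   # compiled kernel, callable with torch tensors
# a = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
# b = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
# c = kernel(a, b)                  # C is allocated and returned per out_idx
```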

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 18889821
......@@ -29,6 +29,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
return dense_mask
@tilelang.jit(out_idx=[4])
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
......@@ -191,9 +192,8 @@ def test_topk_sparse_attention():
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# Run Triton kernel
program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = tilelang.compile(program, out_idx=[4])
# Run tilelang kernel
kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
tilelang_output = kernel(q, k, v, block_mask)
......
......@@ -16,6 +16,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(out_idx=[-1])
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
max_selected_blocks):
shape_q = [batch, heads, dim]
......@@ -200,7 +201,7 @@ class SparseFlashAttn(torch.nn.Module):
self.block_H = 64
program = flashattn(batch, heads, heads_kv, dim, dim_v)(
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
......@@ -209,9 +210,6 @@ class SparseFlashAttn(torch.nn.Module):
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))
self.kernel = tilelang.compile(
program, out_idx=-1, target='cuda', execution_backend="cython")
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
......@@ -305,7 +303,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
is_causal_or_local=True,
max_splits=128)
program = flashattn(batch, heads, heads_kv, dim, dim_v)(
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
......@@ -314,14 +316,6 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
# print(kernel.get_kernel_source())
# output = kernel(query, key, value, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
return output
......@@ -455,7 +449,6 @@ def main(batch=8,
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
block_size)
# out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3)
......
......@@ -17,6 +17,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(out_idx=[-1])
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
......@@ -186,7 +187,7 @@ class SparseFlashAttn(torch.nn.Module):
self.block_H = 64
program = flashattn(batch, heads, heads_kv, dim, dim_v)(
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
......@@ -195,9 +196,6 @@ class SparseFlashAttn(torch.nn.Module):
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
num_blocks=T.symbolic("num_blocks"))
self.kernel = tilelang.compile(
program, out_idx=-1, target='cuda', execution_backend="cython")
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
......@@ -278,7 +276,7 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
is_causal_or_local=True,
max_splits=128)
program = flashattn(batch, heads, heads_kv, dim, dim_v)(
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
......@@ -290,7 +288,6 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
# print(kernel.get_kernel_source())
output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
......
......@@ -139,6 +139,7 @@ def get_best_config(M, N, K):
return autotuner.run(warmup=3, rep=20)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M,
N,
K,
......@@ -208,10 +209,9 @@ def main():
print(f"Best Kernel Latency: {best_latency:.6f} ms")
print(f"Reference Latency: {ref_latency:.6f} ms")
else:
func = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
DEFAULT_ENABLE_RASTERIZATION)
kernel = tilelang.compile(func, out_idx=-1)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
......
......@@ -9,6 +9,7 @@ dtype = "bfloat16"
accum_dtype = "float"
@tilelang.jit(out_idx=[2, 3])
def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
group_size = 128
fp8_min = -448.0
......@@ -176,13 +177,7 @@ def main():
print("batch_sizes:", batch_sizes)
print("M_max:", M_max)
program = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
kernel = tilelang.compile(
program,
out_idx=[2, 3],
target="cuda",
execution_backend="cython",
pass_configs={"tl.disable_tma_lower": True})
kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
print(kernel.get_kernel_source())
# profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
......
......@@ -7,6 +7,7 @@ from tilelang.utils.tensor import torch_assert_close
tilelang.disable_cache()
@tilelang.jit(out_idx=[1, 2])
def per_token_cast_to_fp8(M, N, blk_m):
dtype = "float"
group_size = 128
......@@ -80,13 +81,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def main():
M, N, blk_m = 8192, 8192, 8
program = per_token_cast_to_fp8(M, N, blk_m)
kernel = tilelang.compile(
program,
out_idx=[1, 2],
target="cuda",
execution_backend="cython",
pass_configs={"tl.disable_tma_lower": True})
kernel = per_token_cast_to_fp8(M, N, blk_m)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
......
......@@ -2,13 +2,14 @@ from typing import Tuple
import torch
import tilelang.testing
import tilelang as TL
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
tilelang.testing.set_random_seed(42)
@tilelang.jit(out_idx=[2])
def tl_gemm(
M,
N,
......@@ -144,8 +145,7 @@ def calc_diff(x, y):
def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
gemm = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
kernel = TL.compile(gemm, out_idx=[])
kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
......
......@@ -9,6 +9,7 @@ import argparse
tilelang.disable_cache()
@tilelang.jit(out_idx=[6])
def flashmla_decode(batch,
heads,
kv_head_num,
......@@ -287,9 +288,8 @@ if __name__ == "__main__":
BLOCK_H = 64
num_split = 4
program = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
num_split)
kernel = tilelang.compile(program, out_idx=[6])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors)
......
......@@ -434,9 +434,8 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size)
kernel = tilelang.compile(program, out_idx=[8])
def flash_mla_tilelang():
out = kernel(
......
......@@ -7,6 +7,7 @@ from einops import rearrange, einsum
import argparse
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
......@@ -287,8 +288,7 @@ def main():
BLOCK_H = 64
num_split = 1
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
......
......@@ -7,6 +7,7 @@ from tilelang.profiler import do_bench
import math
@tilelang.jit(out_idx=[8])
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
block_size):
scale = (1.0 / (dv + dpe))**0.5 * 1.44269504 # log2(e)
......@@ -323,9 +324,8 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size)
kernel = tilelang.compile(program, out_idx=[8])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
def flash_mla_tilelang():
......
......@@ -8,6 +8,7 @@ from einops import rearrange, einsum
import argparse
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
......@@ -207,8 +208,7 @@ def main():
BLOCK_H = 64
num_split = 2
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
......
......@@ -7,6 +7,7 @@ from einops import rearrange, einsum
import argparse
@tilelang.jit(out_idx=[-1])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
......@@ -146,9 +147,7 @@ if __name__ == "__main__":
BLOCK_N = 64
BLOCK_H = 64
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
print(program)
kernel = tilelang.compile(program, out_idx=-1)
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
......
......@@ -8,6 +8,7 @@ import tilelang.testing
tilelang.testing.set_random_seed(42)
@tilelang.jit(out_idx=[-1])
def native_sparse_attention(
batch,
heads,
......@@ -130,7 +131,7 @@ def main():
B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
groups = HQ // H
SEQ_LEN_Q = 1
program = native_sparse_attention(
kernel = native_sparse_attention(
batch=B,
heads=HQ,
seq_len=SEQ_LEN,
......@@ -140,7 +141,6 @@ def main():
selected_blocks=S,
)
kernel = tilelang.compile(program, out_idx=-1)
Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
......
......@@ -8,6 +8,7 @@ import tilelang.testing
tilelang.testing.set_random_seed(0)
@tilelang.jit(out_idx=[-1])
def native_sparse_attention(batch,
heads,
seq_len,
......@@ -128,7 +129,7 @@ def native_sparse_attention(batch,
def main():
B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
program = native_sparse_attention(
kernel = native_sparse_attention(
batch=B,
heads=HQ,
seq_len=SEQ_LEN,
......@@ -139,7 +140,6 @@ def main():
selected_blocks=S,
scale=scale,
)
kernel = tilelang.compile(program, out_idx=-1)
print(kernel.get_kernel_source())
torch.random.manual_seed(0)
Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
......
......@@ -16,6 +16,7 @@ from reference import naive_nsa
from einops import rearrange
@tilelang.jit
def native_sparse_attention_varlen(batch,
heads,
c_seq_len,
......@@ -171,7 +172,7 @@ def parallel_nsa_fwd(
BS = block_size
WS = window_size
program = native_sparse_attention_varlen(
kernel = native_sparse_attention_varlen(
batch=batch,
heads=HQ,
c_seq_len=C_SEQ_LEN,
......@@ -182,8 +183,6 @@ def parallel_nsa_fwd(
selected_blocks=S,
)
kernel = tilelang.compile(program)
o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
kernel(
q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D),
......
......@@ -8,6 +8,7 @@ import tilelang.language as T
tilelang.testing.set_random_seed(0)
@tilelang.jit(out_idx=[2])
def matmul(
M,
N,
......@@ -98,7 +99,7 @@ def run_gemm(
num_stages=3,
num_threads=128,
):
program = matmul(
kernel = matmul(
M,
N,
K,
......@@ -112,7 +113,6 @@ def run_gemm(
num_threads,
)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
out = profiler.run_once()
......@@ -435,7 +435,6 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
def main():
test_run_dequantize_gemm()
test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4()
if __name__ == "__main__":
......
......@@ -54,6 +54,7 @@ def torch_convert(tensor):
return new_tensor
@tilelang.jit(out_idx=[1])
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
......@@ -89,7 +90,7 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
def test_fp4_fp16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
program = test_convert(
kernel = test_convert(
N,
K,
block_N,
......@@ -97,8 +98,6 @@ def test_fp4_fp16_convert_close():
"float16",
)
kernel = tilelang.compile(program, out_idx=[1])
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = kernel(B)
ref_out = torch_convert(B)
......@@ -128,6 +127,7 @@ def get_configs():
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
......@@ -270,10 +270,9 @@ def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
program = matmul(
kernel = matmul(
m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......
......@@ -7,6 +7,7 @@ from tilelang.quantize import (
_tir_packed_int_to_int_convert,)
@tilelang.jit
def dequantize_gemv(
M: int,
N: int,
......@@ -173,12 +174,10 @@ def main() -> None:
group_size = -1
with_scaling = False
program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
source_format, n_partition, reduce_thread, fast_decoding, trans_A,
trans_B, group_size, with_scaling)
kernel = tilelang.compile(program)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
......
......@@ -7,6 +7,7 @@ tilelang.testing.set_random_seed(0)
tilelang.disable_cache()
@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8})
def matmul_dynamic_mnk(
block_M,
block_N,
......@@ -60,15 +61,9 @@ def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtyp
print(
f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
)
program = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
kernel = tilelang.compile(
program, pass_configs={
"tl.disable_dynamic_tail_split": True,
"tl.dynamic_alignment": 8
})
import torch
if trans_A:
A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
......