Commit 3db18726 authored by Cunxiao Ni, committed by LeiWang1999

[Example] Update examples to use @tilelang.jit (#597)



* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored the examples to drop the explicit `tilelang.compile` call; the decorated kernel factories are now invoked directly and return ready-to-run kernels (a before/after sketch follows this list).
- Added `@tilelang.jit` decorators with the appropriate `out_idx` values, carrying over any `pass_configs` that were previously passed to `tilelang.compile`.
- Simplified kernel invocation across the examples so kernels are defined and executed consistently.
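
For context, the call-site change applied throughout this diff follows the shape sketched below. The sketch is illustrative rather than part of the commit: the kernel body is borrowed from the tile-lang quickstart matmul, and the inner function name `main`, the 1024 problem size, and the 128/128/64 tile sizes are placeholders chosen for the example.

import tilelang
import tilelang.language as T


# Previously the examples built a program and compiled it explicitly:
#     program = matmul(M, N, K, block_M, block_N, block_K)
#     kernel = tilelang.compile(program, out_idx=[2])
# With @tilelang.jit, calling the decorated factory returns a ready-to-run kernel.
@tilelang.jit(out_idx=[2])  # the buffer at argument index 2 (C) is allocated by the runtime and returned
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # Illustrative kernel body following the tile-lang quickstart matmul.
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


if __name__ == "__main__":
    import torch

    kernel = matmul(1024, 1024, 1024, 128, 128, 64)  # returns a compiled, callable kernel
    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    c = kernel(a, b)  # C is created automatically because of out_idx=[2]
    torch.testing.assert_close(c, (a @ b).to(torch.float16), rtol=1e-2, atol=1e-2)

Examples that previously passed `pass_configs` (e.g. "tl.disable_tma_lower") to `tilelang.compile` now hand the same dictionary to the decorator, as the grouped-GEMM and rms_norm hunks below show.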

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 18889821
@@ -54,6 +54,7 @@ print(f"{iters_per_tile=} ")
 sm_patition_factor = max(blocking_tiles // total_sm, 1)
+@tilelang.jit
 def tl_matmul_streamk(
     M,
     N,
@@ -170,7 +171,7 @@ def tl_matmul_streamk(
 def main():
-    _tl_matmul_streamk = tl_matmul_streamk(
+    kernel = tl_matmul_streamk(
         m,
         n,
         k,
@@ -187,7 +188,6 @@ def main():
         64,
     )
-    kernel = tilelang.compile(_tl_matmul_streamk)
     print(kernel.get_kernel_source())
     b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16)
......
@@ -11,6 +11,7 @@ def ref_program(A, B):
     return A @ B.T
+@tl.jit(out_idx=[-1])
 def naive_gemv(
     N: int,
     K: int,
@@ -44,6 +45,7 @@ def naive_gemv(
     return main
+@tl.jit(out_idx=[-1])
 def naive_splitk_gemv(
     N: int,
     K: int,
@@ -79,6 +81,7 @@ def naive_splitk_gemv(
     return main
+@tl.jit(out_idx=[-1])
 def splitk_gemv(
     N: int,
     K: int,
@@ -118,6 +121,7 @@ def splitk_gemv(
     return main
+@tl.jit(out_idx=[-1])
 def splitk_gemv_vectorized(
     N: int,
     K: int,
@@ -158,6 +162,7 @@ def splitk_gemv_vectorized(
     return main
+@tl.jit(out_idx=[-1])
 def splitk_gemv_vectorized_tvm(
     N: int,
     K: int,
@@ -290,7 +295,6 @@ def get_best_config(N, K):
 def check_correctness_and_bench(kernel, N, K, bench_ref=True):
-    kernel = tl.compile(kernel, out_idx=-1)
     profiler = kernel.get_profiler()
     profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2)
     if bench_ref:
@@ -316,7 +320,6 @@ def main():
     best_result = get_best_config(N, K)
     best_config = best_result.config
     kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
-    kernel = tl.compile(kernel, out_idx=-1)
     profiler = kernel.get_profiler()
     latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
     print(f"Torch Latency: {latency} ms")
......
@@ -7,6 +7,11 @@ import tilelang.language as T
 tilelang.disable_cache()
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def grouped_gemm_fwd(batch_sum,
                      batch_count,
                      K,
@@ -103,16 +108,9 @@ class _GroupedGEMM(torch.autograd.Function):
         batch_padded_offsets = torch.tensor(
             batch_padded_offsets_list, device=a.device, dtype=torch.int32)
-        program = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K,
-                                   num_stages, threads)
-        kernel = tilelang.compile(
-            program,
-            out_idx=[2],
-            pass_configs={
-                "tl.disable_tma_lower": True,
-                "tl.disable_warp_specialized": True
-            })
+        kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K,
+                                  num_stages, threads)
         o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets)
         ctx.save_for_backward(a, b, batch_sizes, batch_offsets)
         ctx.batch_sum = batch_sum
@@ -139,15 +137,8 @@ class _GroupedGEMM(torch.autograd.Function):
            return x
         A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)]
-        program = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K,
-                                   num_stages, threads)
-        kernel = tilelang.compile(
-            program,
-            out_idx=[2],
-            pass_configs={
-                "tl.disable_tma_lower": True,
-                "tl.disable_warp_specialized": True
-            })
+        kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K,
+                                  num_stages, threads)
         dB = kernel(A, grad_output, batch_sizes, batch_offsets)
         return None, dB, None
@@ -198,6 +189,11 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
     return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def grouped_gemm_bwd(batch_sum,
                      batch_count,
                      M,
......
@@ -7,6 +7,11 @@ import math
 tilelang.disable_cache()
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     """
     Perform grouped matrix multiplication using PyTorch.
@@ -39,6 +44,11 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     return output
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def grouped_gemm(batch_sizes_list,
                  K,
                  N,
@@ -140,14 +150,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
                              profile=False):
     padding_M = block_M
     batch_sum = sum(batch_sizes_list)
-    program = grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, num_stages, threads)
-    kernel = tilelang.compile(
-        program,
-        out_idx=[2],
-        pass_configs={
-            "tl.disable_tma_lower": True,
-            "tl.disable_warp_specialized": True
-        })
+    kernel = grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, num_stages, threads)
     # print(kernel.get_kernel_source())
     device = torch.device("cuda")
......
@@ -13,6 +13,7 @@ def is_pow_of_2(n):
     return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
+@tilelang.jit(out_idx=[1])
 def hadamard(b, n, dtype):
     assert is_pow_of_2(n), "n must be a power of 2"
     assert 2 <= n <= 32768, "n must be in [2, 32768]"
@@ -142,7 +143,7 @@ def main():
     B, D = args.batch, args.dim
     x = torch.randn((B, D), device='cuda')
-    kernel = tilelang.compile(hadamard(B, D, 'float32'), out_idx=1)
+    kernel = hadamard(B, D, 'float32')
     y = kernel(x)
     y_ref = ref_program(x)
     torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
......
@@ -7,6 +7,7 @@ import argparse
 from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
+@tl.jit(out_idx=[4, 5, 6])
 def chunk_linear_attn_bwd_kernel(
     B,
     S,
@@ -155,8 +156,7 @@ def main():
     v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
     do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
-    fn = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
-    kernel = tl.compile(fn, out_idx=[4, 5, 6], target='cuda')
+    kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
     dq, dk, dv = postprocess(*kernel(q, k, v, do))
     o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
     o_ref.backward(do, retain_graph=True)
......
@@ -7,6 +7,7 @@ import argparse
 from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
+@tl.jit(out_idx=[3, 4])
 def chunk_linear_attn_fwd_kernel(
     B,
     S,
@@ -97,8 +98,7 @@ def main():
     k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
     v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
-    fn = chunk_linear_attn_fwd_kernel(B, S, H, D, D)
-    kernel = tl.compile(fn, out_idx=[3, 4], target='cuda')
+    kernel = chunk_linear_attn_fwd_kernel(B, S, H, D, D)
     o, h = postprocess(*kernel(q, k, v))
     o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
......
@@ -79,6 +79,7 @@ def get_configs():
     return configs
+@tilelang.jit(out_idx=[7])
 def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
     dtype = "float16"
     accum_dtype = "float"
@@ -229,10 +230,9 @@ if __name__ == "__main__":
     total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
     if (not args.tune):
-        program = chunk_scan_fwd(
-            batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
-                block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128)
-        kernel = tilelang.compile(program, out_idx=[7])
+        kernel = chunk_scan_fwd(
+            batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
+                block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128)
         profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
......
@@ -62,6 +62,7 @@ def get_configs():
     return configs
+@tilelang.jit(out_idx=[4])
 def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
     dtype = "float16"
     accum_dtype = "float"
@@ -166,10 +167,9 @@ if __name__ == "__main__":
     total_flops = 2 * batch * seq_len * heads * dim * dstate
     if (not args.tune):
-        program = chunk_state_fwd(
-            batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
-                block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
-        kernel = tilelang.compile(program, out_idx=[4])
+        kernel = chunk_state_fwd(
+            batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
+                block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
         profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
......
@@ -4,6 +4,7 @@ import tilelang
 import tilelang.language as T
+@tilelang.jit(out_idx=[4])
 def retnet(batch, heads, seq_len, dim_qk, dim_v, block_M, block_N):
     qk_shape = [batch, seq_len, heads, dim_qk]
     v_shape = [batch, seq_len, heads, dim_v]
@@ -179,8 +180,7 @@ if __name__ == "__main__":
     total_flops = 2.0 * BATCH * H * N_CTX * N_CTX * (dim_qk + dim_v)
     BLOCK_M = 64
     BLOCK_N = 64
-    program = retnet(BATCH, H, N_CTX, dim_qk, dim_v, BLOCK_M, BLOCK_N)
-    kernel = tilelang.compile(program, out_idx=[4])
+    kernel = retnet(BATCH, H, N_CTX, dim_qk, dim_v, BLOCK_M, BLOCK_N)
     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
     ins = profiler._get_inputs()
......
@@ -33,6 +33,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
     return main
+@tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True})
 def rms_norm(M, N, blk_m):
     dtype = "float"
@@ -64,13 +65,7 @@ def ref_program(x):
 if __name__ == "__main__":
     M, N, blk_m, blk_k = 8192, 8192, 1, 512
-    program = rms_norm(M, N, blk_m)
-    kernel = tilelang.compile(
-        program,
-        out_idx=-1,
-        target="cuda",
-        execution_backend="cython",
-        pass_configs={"tl.disable_tma_lower": True})
+    kernel = rms_norm(M, N, blk_m)
     profiler = kernel.get_profiler()
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     print("All checks pass.")
@@ -78,4 +73,4 @@ if __name__ == "__main__":
     latency = profiler.do_bench(ref_program, warmup=500)
     print("Ref: {:.2f} ms".format(latency))
     latency = profiler.do_bench(warmup=500)
-    print("Tile-lang: {:.2f} ms".format(latency))
\ No newline at end of file
+    print("Tile-lang: {:.2f} ms".format(latency))
@@ -5,6 +5,7 @@ from tilelang.profiler import do_bench
 from typing import Callable
+@tl.jit(out_idx=[1])
 def softmax_kernel(
     M,
     N,
@@ -59,10 +60,9 @@ def softmax_kernel(
 M = 8192
 N = 8192
-fn = softmax_kernel(M, N)
+kernel = softmax_kernel(M, N)
 dtype = torch.float16
 X = torch.randn(M, N, dtype=dtype, device="cuda")
-kernel = tl.compile(fn, out_idx=[1], target="cuda")
 Y = kernel(X).to(dtype)
 Y_ref = X.softmax(dim=1)
......
@@ -29,6 +29,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
     return dense_mask
+@tilelang.jit(out_idx=[4])
 def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
@@ -174,10 +175,9 @@ def test_topk_sparse_attention():
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
-    # Run Triton kernel
-    program = blocksparse_flashattn(
-        BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
-    kernel = tilelang.compile(program, out_idx=[4])
+    # Run tilelang kernel
+    kernel = blocksparse_flashattn(
+        BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
     print(kernel.get_kernel_source())
     tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
@@ -224,10 +224,8 @@ def test_topk_sparse_attention_qlen_lt_klen():
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
-    program = blocksparse_flashattn(
-        BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
-    print(program)
-    kernel = tilelang.compile(program, out_idx=[4])
+    kernel = blocksparse_flashattn(
+        BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
     print(kernel.get_kernel_source())
     tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
@@ -265,4 +263,4 @@ def main():
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
@@ -7,6 +7,7 @@ import tilelang.language as T
 from einops import rearrange, einsum
+@tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
@@ -172,8 +173,7 @@ def main():
     BLOCK_H = 64
     num_split = 1
-    program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
-    kernel = tilelang.compile(program, out_idx=[6])
+    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
......
@@ -4,6 +4,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     num_stages = 2
@@ -57,19 +58,10 @@ def main():
     block_M = 128
     block_N = 128
     block_K = 64
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul(M, N, K, block_M, block_N, block_K)
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(func, out_idx=[2])
-    # 3. Test the kernel in Python with PyTorch data
+    jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
     import torch
-    # Create random input tensors on the GPU
     a = torch.randn(M, K, device="cuda", dtype=torch.float16)
     b = torch.randn(K, N, device="cuda", dtype=torch.float16)
......
@@ -4,6 +4,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul_warp_specialize_copy_0_gemm_1(M,
                                          N,
                                          K,
@@ -56,19 +57,8 @@ def main():
     block_N = 128
     block_K = 64
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K)
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(
-        func,
-        out_idx=[2],
-    )
-    # 3. Test the kernel in Python with PyTorch data
+    jit_kernel = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K)
     import torch
     # Create random input tensors on the GPU
......
@@ -4,6 +4,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul_warp_specialize_copy_1_gemm_0(M,
                                          N,
                                          K,
@@ -56,22 +57,10 @@ def main():
     block_N = 128
     block_K = 64
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(
-        func,
-        out_idx=[2],
-    )
-    # 3. Test the kernel in Python with PyTorch data
+    jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
     import torch
-    # Create random input tensors on the GPU
     a = torch.randn(M, K, device="cuda", dtype=torch.float16)
     b = torch.randn(K, N, device="cuda", dtype=torch.float16)
......
@@ -6,6 +6,12 @@ tilelang.disable_cache()
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(
+    out_idx=[2],
+    pass_configs={
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        # tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def matmul_warp_specialize_copy_1_gemm_0(M,
                                          N,
                                          K,
@@ -59,21 +65,7 @@ def main():
     block_N = 128
     block_K = 64
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
-    # print(func.script())
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(
-        func,
-        out_idx=[2],
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-            # tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+    jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
     print(jit_kernel.get_kernel_source())
     # 3. Test the kernel in Python with PyTorch data
     import torch
......
@@ -4,6 +4,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     @T.prim_func
@@ -48,15 +49,10 @@ def main():
     block_M = 128
     block_N = 128
     block_K = 64
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul(M, N, K, block_M, block_N, block_K)
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
+    jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
     tilelang.disable_cache()
-    jit_kernel = tilelang.compile(func, out_idx=[2])
     # 3. Test the kernel in Python with PyTorch data
     import torch
......