Commit 3db18726 authored by Cunxiao Ni, committed by LeiWang1999

[Example] Update examples to use @tilelang.jit (#597)



* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored multiple examples to drop the explicit `tilelang.compile` step for kernel creation; the decorated functions are now invoked directly.
- Added `@tilelang.jit` decorators with the appropriate output indices to improve performance and maintainability (a minimal before/after sketch of the pattern follows the commit metadata below).
- Simplified kernel invocation across the examples so kernels are defined and executed consistently.

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 18889821
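
The refactor follows one pattern throughout the diff: instead of building a `T.prim_func` and passing it to `tilelang.compile(...)`, the kernel factory is decorated with `@tilelang.jit` (optionally with `out_idx`, `pass_configs`, etc.), so calling the factory returns a compiled kernel directly. The sketch below illustrates that pattern with a hypothetical elementwise-add kernel; `add_kernel`, its shapes, and block sizes are illustrative only and are not part of this commit.

```python
import tilelang
import tilelang.language as T

# Before (old style, shown here only in comments):
#   program = add_kernel(1024, 1024)                 # returns a T.prim_func
#   kernel = tilelang.compile(program, out_idx=[-1])  # explicit compile step

# After: decorate the factory; calling it yields a compiled kernel.
@tilelang.jit(out_idx=[-1])  # out_idx marks which argument(s) are allocated and returned as outputs
def add_kernel(M, N, block_M=128, block_N=128, dtype="float16"):

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # elementwise add of one (block_M, block_N) tile per thread block
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = (
                    A[by * block_M + i, bx * block_N + j] + B[by * block_M + i, bx * block_N + j])

    return main


kernel = add_kernel(1024, 1024)  # compiled kernel; no separate tilelang.compile(...) call
# c = kernel(a, b)               # the tensor listed in out_idx is allocated and returned
```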
@@ -29,6 +29,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 return dense_mask
+@tilelang.jit(out_idx=[4])
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
 block_M = 64
 block_N = 64
@@ -191,9 +192,8 @@ def test_topk_sparse_attention():
 x_ds[:, :, :, 0] = 100
 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
-# Run Triton kernel
-program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
-kernel = tilelang.compile(program, out_idx=[4])
+# Run tilelang kernel
+kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
 tilelang_output = kernel(q, k, v, block_mask)
...
@@ -16,6 +16,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
 accum_dtype = "float"
 kv_group_num = heads // heads_kv
+@tilelang.jit(out_idx=[-1])
 def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
 max_selected_blocks):
 shape_q = [batch, heads, dim]
@@ -200,7 +201,7 @@ class SparseFlashAttn(torch.nn.Module):
 self.block_H = 64
-program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=self.block_H,
 num_split=T.symbolic("num_split"),
@@ -209,9 +210,6 @@ class SparseFlashAttn(torch.nn.Module):
 max_cache_seqlen=T.symbolic("max_cache_seqlen"),
 max_selected_blocks=T.symbolic("max_selected_blocks"))
-self.kernel = tilelang.compile(
-program, out_idx=-1, target='cuda', execution_backend="cython")
 props = torch.cuda.get_device_properties(torch.device("cuda:0"))
 self.num_sm = props.multi_processor_count
@@ -305,7 +303,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
 is_causal_or_local=True,
 max_splits=128)
-program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
+Output_partial = torch.empty((batch, heads, num_split, dim_v),
+dtype=torch.float32,
+device='cuda')
+kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=block_H,
 num_split=T.symbolic("num_split"),
@@ -314,14 +316,6 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
 max_cache_seqlen=T.symbolic("max_cache_seqlen"),
 max_selected_blocks=T.symbolic("max_selected_blocks"))
-glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-Output_partial = torch.empty((batch, heads, num_split, dim_v),
-dtype=torch.float32,
-device='cuda')
-kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
-# print(kernel.get_kernel_source())
-# output = kernel(query, key, value, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
 output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
 return output
@@ -455,7 +449,6 @@ def main(batch=8,
 ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
 block_size)
-# out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
 sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
 out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
 debug("output", ref, out, atol=1e-3, rtol=1e-3)
...
@@ -17,6 +17,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
 accum_dtype = "float"
 kv_group_num = heads // heads_kv
+@tilelang.jit(out_idx=[-1])
 def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
 shape_q = [batch, heads, dim]
 shape_k = [batch, max_cache_seqlen, heads_kv, dim]
@@ -186,7 +187,7 @@ class SparseFlashAttn(torch.nn.Module):
 self.block_H = 64
-program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=self.block_H,
 num_split=T.symbolic("num_split"),
@@ -195,9 +196,6 @@ class SparseFlashAttn(torch.nn.Module):
 max_cache_seqlen=T.symbolic("max_cache_seqlen"),
 num_blocks=T.symbolic("num_blocks"))
-self.kernel = tilelang.compile(
-program, out_idx=-1, target='cuda', execution_backend="cython")
 props = torch.cuda.get_device_properties(torch.device("cuda:0"))
 self.num_sm = props.multi_processor_count
@@ -278,7 +276,7 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
 is_causal_or_local=True,
 max_splits=128)
-program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=block_H,
 num_split=T.symbolic("num_split"),
@@ -290,7 +288,6 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
 Output_partial = torch.empty((batch, heads, num_split, dim_v),
 dtype=torch.float32,
 device='cuda')
-kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
 # print(kernel.get_kernel_source())
 output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
...
@@ -139,6 +139,7 @@ def get_best_config(M, N, K):
 return autotuner.run(warmup=3, rep=20)
+@tilelang.jit(out_idx=[-1])
 def blocksparse_matmul(M,
 N,
 K,
@@ -208,10 +209,9 @@ def main():
 print(f"Best Kernel Latency: {best_latency:.6f} ms")
 print(f"Reference Latency: {ref_latency:.6f} ms")
 else:
-func = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
+kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
 DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
 DEFAULT_ENABLE_RASTERIZATION)
-kernel = tilelang.compile(func, out_idx=-1)
 block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
 print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
...
@@ -9,6 +9,7 @@ dtype = "bfloat16"
 accum_dtype = "float"
+@tilelang.jit(out_idx=[2, 3])
 def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
 group_size = 128
 fp8_min = -448.0
@@ -176,13 +177,7 @@ def main():
 print("batch_sizes:", batch_sizes)
 print("M_max:", M_max)
-program = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
-kernel = tilelang.compile(
-program,
-out_idx=[2, 3],
-target="cuda",
-execution_backend="cython",
-pass_configs={"tl.disable_tma_lower": True})
+kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
 print(kernel.get_kernel_source())
 # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
...
@@ -7,6 +7,7 @@ from tilelang.utils.tensor import torch_assert_close
 tilelang.disable_cache()
+@tilelang.jit(out_idx=[1, 2])
 def per_token_cast_to_fp8(M, N, blk_m):
 dtype = "float"
 group_size = 128
@@ -80,13 +81,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 def main():
 M, N, blk_m = 8192, 8192, 8
-program = per_token_cast_to_fp8(M, N, blk_m)
-kernel = tilelang.compile(
-program,
-out_idx=[1, 2],
-target="cuda",
-execution_backend="cython",
-pass_configs={"tl.disable_tma_lower": True})
+kernel = per_token_cast_to_fp8(M, N, blk_m)
 print(kernel.get_kernel_source())
 profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
...
@@ -2,13 +2,14 @@ from typing import Tuple
 import torch
 import tilelang.testing
-import tilelang as TL
+import tilelang
 import tilelang.language as T
 from tilelang.utils.tensor import map_torch_type
 tilelang.testing.set_random_seed(42)
+@tilelang.jit(out_idx=[2])
 def tl_gemm(
 M,
 N,
@@ -144,8 +145,7 @@ def calc_diff(x, y):
 def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
-gemm = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
-kernel = TL.compile(gemm, out_idx=[])
+kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
 src_code = kernel.get_kernel_source()
 # src_code is the generated cuda source
...
@@ -9,6 +9,7 @@ import argparse
 tilelang.disable_cache()
+@tilelang.jit(out_idx=[6])
 def flashmla_decode(batch,
 heads,
 kv_head_num,
@@ -287,9 +288,8 @@ if __name__ == "__main__":
 BLOCK_H = 64
 num_split = 4
-program = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
+kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
 num_split)
-kernel = tilelang.compile(program, out_idx=[6])
 profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
 input_tensors = profiler._get_inputs()
 tilelang_output = kernel(*input_tensors)
...
@@ -434,9 +434,8 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
 out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
 glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
-program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
+kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
 num_kv_splits, block_size)
-kernel = tilelang.compile(program, out_idx=[8])
 def flash_mla_tilelang():
 out = kernel(
...
@@ -7,6 +7,7 @@ from einops import rearrange, einsum
 import argparse
+@tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
 scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
 dtype = "float16"
@@ -287,8 +288,7 @@ def main():
 BLOCK_H = 64
 num_split = 1
-program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
-kernel = tilelang.compile(program, out_idx=[6])
+kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
 profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
 profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
 latency = profiler.do_bench(warmup=500)
@@ -297,4 +297,4 @@ def main():
 if __name__ == "__main__":
 main()
\ No newline at end of file
...
@@ -7,6 +7,7 @@ from tilelang.profiler import do_bench
 import math
+@tilelang.jit(out_idx=[8])
 def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
 block_size):
 scale = (1.0 / (dv + dpe))**0.5 * 1.44269504 # log2(e)
@@ -323,9 +324,8 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
 out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
 glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
-program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
+kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
 num_kv_splits, block_size)
-kernel = tilelang.compile(program, out_idx=[8])
 profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
 def flash_mla_tilelang():
...
@@ -8,6 +8,7 @@ from einops import rearrange, einsum
 import argparse
+@tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
 scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
 dtype = "float16"
@@ -207,8 +208,7 @@ def main():
 BLOCK_H = 64
 num_split = 2
-program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
-kernel = tilelang.compile(program, out_idx=[6])
+kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
 print(kernel.get_kernel_source())
 profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
 profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
...
@@ -7,6 +7,7 @@ from einops import rearrange, einsum
 import argparse
+@tilelang.jit(out_idx=[-1])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
 scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
 dtype = "float16"
@@ -146,9 +147,7 @@ if __name__ == "__main__":
 BLOCK_N = 64
 BLOCK_H = 64
-program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
-print(program)
-kernel = tilelang.compile(program, out_idx=-1)
+kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
 profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
 latency = profiler.do_bench(warmup=500)
 print(f"Latency: {latency} ms")
...
@@ -8,6 +8,7 @@ import tilelang.testing
 tilelang.testing.set_random_seed(42)
+@tilelang.jit(out_idx=[-1])
 def native_sparse_attention(
 batch,
 heads,
@@ -130,7 +131,7 @@ def main():
 B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
 groups = HQ // H
 SEQ_LEN_Q = 1
-program = native_sparse_attention(
+kernel = native_sparse_attention(
 batch=B,
 heads=HQ,
 seq_len=SEQ_LEN,
@@ -140,7 +141,6 @@ def main():
 selected_blocks=S,
 )
-kernel = tilelang.compile(program, out_idx=-1)
 Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
 K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
 V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
...
@@ -8,6 +8,7 @@ import tilelang.testing
 tilelang.testing.set_random_seed(0)
+@tilelang.jit(out_idx=[-1])
 def native_sparse_attention(batch,
 heads,
 seq_len,
@@ -128,7 +129,7 @@ def native_sparse_attention(batch,
 def main():
 B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
-program = native_sparse_attention(
+kernel = native_sparse_attention(
 batch=B,
 heads=HQ,
 seq_len=SEQ_LEN,
@@ -139,7 +140,6 @@ def main():
 selected_blocks=S,
 scale=scale,
 )
-kernel = tilelang.compile(program, out_idx=-1)
 print(kernel.get_kernel_source())
 torch.random.manual_seed(0)
 Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
...
@@ -16,6 +16,7 @@ from reference import naive_nsa
 from einops import rearrange
+@tilelang.jit
 def native_sparse_attention_varlen(batch,
 heads,
 c_seq_len,
@@ -171,7 +172,7 @@ def parallel_nsa_fwd(
 BS = block_size
 WS = window_size
-program = native_sparse_attention_varlen(
+kernel = native_sparse_attention_varlen(
 batch=batch,
 heads=HQ,
 c_seq_len=C_SEQ_LEN,
@@ -182,8 +183,6 @@ def parallel_nsa_fwd(
 selected_blocks=S,
 )
-kernel = tilelang.compile(program)
 o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
 kernel(
 q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D),
...
@@ -8,6 +8,7 @@ import tilelang.language as T
 tilelang.testing.set_random_seed(0)
+@tilelang.jit(out_idx=[2])
 def matmul(
 M,
 N,
@@ -98,7 +99,7 @@ def run_gemm(
 num_stages=3,
 num_threads=128,
 ):
-program = matmul(
+kernel = matmul(
 M,
 N,
 K,
@@ -112,7 +113,6 @@ def run_gemm(
 num_threads,
 )
-kernel = tilelang.compile(program, out_idx=[2])
 profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
 out = profiler.run_once()
@@ -435,7 +435,6 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
 def main():
 test_run_dequantize_gemm()
-test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4()
 if __name__ == "__main__":
...
@@ -54,6 +54,7 @@ def torch_convert(tensor):
 return new_tensor
+@tilelang.jit(out_idx=[1])
 def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
 num_elems_per_byte = 8 // num_bits
 storage_dtype = "uint8"
@@ -89,7 +90,7 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
 def test_fp4_fp16_convert_close():
 N, K = 256, 256
 block_N, block_K = 64, 64
-program = test_convert(
+kernel = test_convert(
 N,
 K,
 block_N,
@@ -97,8 +98,6 @@ def test_fp4_fp16_convert_close():
 "float16",
 )
-kernel = tilelang.compile(program, out_idx=[1])
 B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
 tl_out = kernel(B)
 ref_out = torch_convert(B)
@@ -128,6 +127,7 @@ def get_configs():
 def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
+@tilelang.jit(out_idx=[2])
 def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
 num_elems_per_byte = 8 // num_bits
 storage_dtype = "uint8"
@@ -270,10 +270,9 @@ def main(m=256, n=256, k=256, tune=False):
 total_flops = 2 * m * n * k
 if (not tune):
-program = matmul(
+kernel = matmul(
 m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
 block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
-kernel = tilelang.compile(program, out_idx=[2])
 profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
 profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
 print("All checks pass.")
...
@@ -7,6 +7,7 @@ from tilelang.quantize import (
 _tir_packed_int_to_int_convert,)
+@tilelang.jit
 def dequantize_gemv(
 M: int,
 N: int,
@@ -173,11 +174,9 @@ def main() -> None:
 group_size = -1
 with_scaling = False
-program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
+kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
 source_format, n_partition, reduce_thread, fast_decoding, trans_A,
 trans_B, group_size, with_scaling)
-kernel = tilelang.compile(program)
 storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
 num_elems_per_byte = storage_nbit // num_bits
...
@@ -7,6 +7,7 @@ tilelang.testing.set_random_seed(0)
 tilelang.disable_cache()
+@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8})
 def matmul_dynamic_mnk(
 block_M,
 block_N,
@@ -60,14 +61,8 @@ def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtyp
 print(
 f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
 )
-program = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
+kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
 accum_dtype, num_stages, threads)
-kernel = tilelang.compile(
-program, pass_configs={
-"tl.disable_dynamic_tail_split": True,
-"tl.dynamic_alignment": 8
-})
 import torch
 if trans_A:
...