Commit 25a50f1a authored by Lei Wang, committed by LeiWang1999

[Refactor] refactor `tilelang.jit` to support a faster and more flexible kernel cache (#501)

* [Refactor] Update JIT kernel functions and streamline GEMM tests

* Renamed and refactored matmul and run_gemm functions to matmul_kernel_jit and run_gemm_kernel_jit for clarity.
* Removed the redundant JIT decorator from inside the matmul function, ensuring it is applied only to the kernel function (see the sketch below).
* Updated test function names to reflect changes in the kernel functions, enhancing consistency and readability.
* Cleaned up commented-out code and unnecessary imports to improve overall code quality.

* Update main function call in GEMM test to use tilelang testing framework

* Update README and example scripts to include JIT decorator comments

* Added comments in README.md and various example scripts to indicate the use of the @tilelang.jit decorator for returning torch functions.
* Removed redundant comments that previously instructed to add the decorator, streamlining the documentation and improving clarity.

* Update GEMM test parameters for improved performance

* Set num_stages to 0 and adjusted matrix dimensions in test functions to enhance performance and consistency across GEMM tests in test_tilelang_kernel_gemm.py.
parent 33937683
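For orientation, the usage pattern the examples move toward looks roughly like the sketch below. The factory signature and the kernel body mirror the GEMM code in this diff; the block sizes, thread count, and the commented-out launch at the end are illustrative assumptions rather than part of the commit.

import tilelang
import tilelang.language as T

# New style: decorate the kernel-factory function itself. tilelang.jit caches
# the compiled kernel per (args, kwargs) and returns it directly.
@tilelang.jit(out_idx=-1)  # out_idx=-1: allocate the output tensor at call time
def matmul_kernel_jit(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Pipelined K loop: stage A/B tiles through shared memory and accumulate.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

# Illustrative launch: the first call compiles and caches; later calls with the
# same arguments reuse the cached kernel.
# kernel = matmul_kernel_jit(512, 1024, 768, 128, 128, 32)
# C = kernel(A, B)  # A, B: half-precision CUDA tensors

The diff of the affected files follows.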
@@ -126,8 +126,10 @@ import tilelang.language as T
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,)
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -4,7 +4,6 @@ from typing import Optional, Union
import torch
import triton
-import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
@@ -13,6 +12,7 @@ from einops import rearrange
import tilelang
+@tilelang.jit
def tilelang_kernel_fwd(
batch,
heads,
@@ -55,7 +55,6 @@ def tilelang_kernel_fwd(
num_stages = 0
threads = 32
-@tilelang.jit
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
@@ -146,6 +145,7 @@ def tilelang_kernel_fwd(
return native_sparse_attention
+@tilelang.jit
def tilelang_kernel_bwd_dkv(
batch,
heads,
@@ -193,7 +193,6 @@ def tilelang_kernel_bwd_dkv(
num_threads = 32
print("NV", NV, "NS", NS, "B", B, "H", H)
-@tilelang.jit
@T.prim_func
def flash_bwd_dkv(
Q: T.Tensor(q_shape, dtype),
@@ -310,6 +309,7 @@ def make_dq_layout(dQ):
)
+@tilelang.jit
def tilelang_kernel_bwd_dqkv(
batch,
heads,
@@ -357,7 +357,6 @@ def tilelang_kernel_bwd_dqkv(
block_mask_shape = [batch, seq_len, heads_kv, NS]
num_threads = 32
-@tilelang.jit
@T.prim_func
def flash_bwd_dqkv(
Q: T.Tensor(q_shape, dtype),
@@ -473,6 +472,7 @@ def tilelang_kernel_bwd_dqkv(
return flash_bwd_dqkv
+@tilelang.jit(out_idx=[2])
def tilelang_kernel_preprocess(
batch,
heads,
@@ -486,7 +486,6 @@ def tilelang_kernel_preprocess(
shape = [batch, seq_len, heads, dim]
-@tilelang.jit(out_idx=[2], execution_backend="cython")
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
@@ -510,6 +509,7 @@ def tilelang_kernel_preprocess(
return flash_bwd_prep
+@tilelang.jit(out_idx=[2])
def tilelang_kernel_block_mask(
batch,
heads,
@@ -529,7 +529,6 @@ def tilelang_kernel_block_mask(
block_mask_shape = [batch, seq_len, heads, NS]
USE_BLOCK_COUNTS = block_counts is not None
-@tilelang.jit(out_idx=[2], execution_backend="cython")
@T.prim_func
def flash_bwd_block_mask(
BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore
...
@@ -8,8 +8,10 @@ from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,) # noqa: F401
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,11 +2,13 @@ import tilelang
import tilelang.language as T
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
num_stages = 2
mbarrier_list = [128, 128] * num_stages
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
...
@@ -2,6 +2,8 @@ import tilelang
import tilelang.language as T
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul_warp_specialize_copy_0_gemm_1(M,
N,
K,
@@ -10,7 +12,7 @@ def matmul_warp_specialize_copy_0_gemm_1(M,
block_K,
dtype="float16",
accum_dtype="float"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,6 +2,8 @@ import tilelang
import tilelang.language as T
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
@@ -10,7 +12,7 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
block_K,
dtype="float16",
accum_dtype="float"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -4,6 +4,8 @@ import tilelang.language as T
tilelang.disable_cache()
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
@@ -15,7 +17,7 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
warp_group_num = 2
threads = 128 * warp_group_num
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,8 +2,10 @@ import tilelang
import tilelang.language as T
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
...
@@ -4,7 +4,10 @@ import tilelang
import torch
-def matmul(
+@tilelang.jit(
+out_idx=-1, # create the output tensor during runtime
+)
+def matmul_kernel_jit(
M,
N,
K,
@@ -26,9 +29,6 @@ def matmul(
import tilelang.language as T
-@tilelang.jit(
-out_idx=-1, # create the output tensor during runtime
-)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -55,7 +55,7 @@ def matmul(
return main
-def run_gemm(
+def run_gemm_kernel_jit(
M,
N,
K,
@@ -70,7 +70,7 @@ def run_gemm(
num_stages=3,
num_threads=128,
):
-matmul_kernel = matmul(
+matmul_kernel = matmul_kernel_jit(
M,
N,
K,
@@ -106,126 +106,8 @@ def run_gemm(
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
-def test_gemm_f16f16f16_nn():
+def test_gemm_f16f16f16_nn_kernel_jit():
-run_gemm(
+run_gemm_kernel_jit(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmu_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmu_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, out_idx=-1)
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
...
@@ -62,7 +62,7 @@ def run_gemm(
block_M,
block_N,
block_K,
-num_stages=3,
+num_stages=0,
num_threads=128,
):
program = matmul(
@@ -109,9 +109,9 @@ def test_gemm_f16f16f16_nn():
"float16",
"float16",
128,
-256,
+128,
32,
-2,
+0,
)
@@ -174,9 +174,9 @@ def test_gemm_f16f16f16_tn():
"float16",
"float16",
128,
-256,
+128,
32,
-2,
+0,
)
@@ -191,9 +191,9 @@ def test_gemm_f16f16f16_nt():
"float16",
"float16",
128,
-256,
+128,
32,
-2,
+0,
)
@@ -401,9 +401,9 @@ def test_gemm_f16f16f16_sr():
"float16",
"float16",
128,
-256,
+128,
32,
-2,
+0,
)
...
@@ -2,8 +2,10 @@ import tilelang
import tilelang.language as T
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,8 +2,10 @@ import tilelang
import tilelang.language as T
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -3,8 +3,10 @@ import tilelang.language as T
import torch
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, dtype="float16"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
...
@@ -3,8 +3,10 @@ import tilelang.language as T
import torch
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
@@ -44,8 +46,10 @@ def test_tilelang_copy_mask_parallel():
run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128)
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
@@ -84,8 +88,10 @@ def test_tilelang_copy_mask_copy():
run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128)
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
@@ -129,8 +135,10 @@ def test_tilelang_copy_mask_parallel_range():
run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128)
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
-# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
...
@@ -4,109 +4,32 @@ It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
-from typing import Callable, List, Literal, Union, Any, Optional, Dict
+from typing import (
+Any,
+List,
+Union,
+Callable,
+Tuple,
+TypeVar,
+overload,
+Literal,
+Dict, # For type hinting dicts
+Optional,
+)
+from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
-from tilelang.jit.adapter import BaseKernelAdapter
from tilelang.jit.kernel import JITKernel
-from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.cache import cached
+from os import path, makedirs
from logging import getLogger
+import functools
logger = getLogger(__name__)
def jit(
func: Callable = None,
*, # Enforce keyword-only arguments from here on
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target: Union[str, Target] = "auto",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
) -> BaseKernelAdapter:
"""
A decorator (or decorator factory) that JIT-compiles a given TileLang PrimFunc
into a runnable kernel adapter using TVM. If called with arguments, it returns
a decorator that can be applied to a function. If called without arguments,
it directly compiles the given function.
Parameters
----------
func : Callable, optional
The TileLang PrimFunc to JIT-compile. If None, this function returns a
decorator that expects a TileLang PrimFunc.
out_idx : Union[List[int], int], optional
The index (or list of indices) of the function outputs. This can be used
to specify which outputs from the compiled function will be returned.
execution_backend : Literal["dlpack", "ctypes"], optional
The wrapper type to use for the kernel adapter. Currently, only "dlpack"
and "ctypes" are supported.
target : Union[str, Target], optional
The compilation target for TVM. If set to "auto", an appropriate target
will be inferred automatically. Otherwise, must be one of the supported
strings in AVALIABLE_TARGETS or a TVM Target instance.
Returns
-------
BaseKernelAdapter
An adapter object that encapsulates the compiled function and can be
used to execute it.
Raises
------
AssertionError
If the provided target is an invalid string not present in AVALIABLE_TARGETS.
"""
# If the target is specified as a string, ensure it is valid and convert to a TVM Target.
if isinstance(target, str):
assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
target = determine_target(target)
target = Target(target)
assert execution_backend in ["dlpack", "ctypes", "cython"], "Invalid execution backend."
def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter:
"""
Compile the given TileLang PrimFunc with TVM and build a kernel adapter.
Parameters
----------
tilelang_func : tvm.tir.PrimFunc
The TileLang (TVM TIR) function to compile.
Returns
-------
BaseKernelAdapter
The compiled and ready-to-run kernel adapter.
"""
if verbose:
logger.info(f"Compiling TileLang function:\n{tilelang_func}")
return compile(
tilelang_func,
target=target,
verbose=verbose,
execution_backend=execution_backend,
out_idx=out_idx,
pass_configs=pass_configs,
).adapter
# If `func` was given, compile it immediately and return the adapter.
if func is not None:
return _compile_and_create_adapter(func)
# Otherwise, return a decorator that expects a function to compile.
def real_decorator(tilelang_func: PrimFunc) -> BaseKernelAdapter:
return _compile_and_create_adapter(tilelang_func)
return real_decorator
def compile(
func: PrimFunc = None,
out_idx: Union[List[int], int, None] = None,
@@ -151,4 +74,219 @@ def compile(
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
)
\ No newline at end of file
# --- Mocking dependencies for the example to run ---
# In your actual code, these would be your real types.
class Program:
"""Placeholder for the type returned by the original decorated function."""
def __init__(self, data: str):
self.data = data
def __repr__(self):
return f"Program('{self.data}')"
class Kernel:
"""Placeholder for the type of the compiled kernel."""
def __init__(self, source: str, out_idx: Any):
self.source_code = source
self.out_idx = out_idx
def get_kernel_source(self) -> str:
return self.source_code
def __repr__(self):
return f"Kernel('{self.source_code[:20]}...')"
# --- End Mocking ---
# P (Parameters) captures the argument types of the decorated function.
_P = ParamSpec("_P")
# R_prog (Return type of Program) captures the return type of the original decorated function.
# We assume the original function returns something compatible with 'Program'.
_RProg = TypeVar("_RProg", bound=Program)
class jit:
# Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@overload
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[True]) -> None:
...
@overload
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[False] = False) -> None:
...
# Actual implementation of __init__
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: bool = False):
"""
Initializes the JIT compiler decorator.
Parameters
----------
out_idx : Any, optional
Index(es) of the output tensors to return from the compiled kernel
(default: None, meaning all outputs are returned or determined by the kernel itself).
target : Union[str, Target], optional
Compilation target for TVM. Can be a string (e.g., "cuda", "llvm")
or a TVM Target object. If "auto", the target is determined automatically
(default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation, similar to `target` (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
The backend used for kernel execution and argument passing.
"dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks.
"ctypes" uses standard C types. "cython" uses Cython for potentially faster execution.
(default: "cython").
verbose : bool, optional
If True, enables verbose logging during compilation (default: False).
pass_configs : Optional[Dict[str, Any]], optional
A dictionary of configurations for TVM's pass context. These can fine-tune
the compilation process. Examples include "tir.disable_vectorize"
(default: None).
debug_root_path : Optional[str], optional
If provided, the compiled kernel's source code will be saved to a file
in this directory. This is useful for debugging the generated code.
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
return_program : bool, optional
If True, the decorated function will return a tuple containing the
original program's result and the compiled kernel. If False, only the
compiled kernel is returned (default: False).
"""
if debug_root_path is None:
# This logic was previously under 'if debug and debug_root_path is None:'
# Now, if debug_root_path is explicitly None, we don't try to set a default path.
# If a user wants debugging, they must provide a path.
pass
elif not path.isabs(debug_root_path): # If a relative path is given, make it absolute
try:
# This assumes the file is part of a typical project structure
base_path = path.dirname(path.dirname(path.dirname(__file__)))
debug_root_path = path.join(base_path, debug_root_path)
except NameError: # __file__ is not defined (e.g., in a REPL or notebook)
# Fallback to making it absolute based on current working directory if __file__ fails
debug_root_path = path.abspath(debug_root_path)
self.out_idx = out_idx
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
self.verbose = verbose
self.pass_configs = pass_configs
self.debug_root_path: Optional[str] = debug_root_path
self.return_program: bool = return_program
# Type hint the caches
self._program_cache: Dict[tuple, _RProg] = {}
self._kernel_cache: Dict[tuple, Kernel] = {}
# Overload __call__ based on the value of self.return_program
# This tells the type checker what the *wrapper* function will return.
# The wrapper will take the same parameters P as the original function.
# Case 1: return_program is True
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
# This signature is chosen by the type checker if self.return_program is True
# (inferred from the __init__ call).
...
# Case 2: return_program is False (or not specified, defaulting to False)
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
# This signature is chosen if self.return_program is False.
...
# Actual implementation of __call__
def __call__(
self, func: Union[Callable[_P, _RProg], PrimFunc]
) -> Callable[_P, Any]: # Any for implementation flexibility
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Use _P.args and _P.kwargs
# Create a hashable key. args is already a tuple.
# For kwargs, convert to a sorted tuple of items to ensure consistent ordering.
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
# Check if both program and kernel are cached.
# If program is not cached, we'll recompute both.
# (The original check 'key not in self._program_cache or key not in self._kernel_cache'
# implies that if either is missing, both are recomputed and stored.
# A simpler 'key not in self._program_cache' would often suffice if they are always
# added together.)
if key not in self._program_cache: # Assuming if program isn't there, kernel isn't either or needs refresh
if isinstance(func, PrimFunc):
program_result = func
elif isinstance(func, Callable):
program_result = func(*args, **kwargs)
else:
raise ValueError(f"Invalid function type: {type(func)}")
kernel_result = compile(
program_result,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
)
if self.debug_root_path: # Check if a path is provided
func_name = func.__name__
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
# Ensure the debug directory exists
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
self._program_cache[key] = program_result
self._kernel_cache[key] = kernel_result
# Retrieve from cache (even if just populated)
cached_program = self._program_cache[key]
cached_kernel = self._kernel_cache[key]
if self.return_program:
return cached_program, cached_kernel
else:
return cached_kernel
return wrapper
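The cache key built in wrapper() above determines when a recompilation happens, so it is worth spelling out. The snippet below is a stand-alone sketch of that keying rule (make_key is a hypothetical helper, not part of the module): keyword arguments are order-normalized by sorting, while passing the same value positionally versus by keyword still produces distinct keys.

# Stand-alone sketch of the cache-key rule used by wrapper() above.
def make_key(args, kwargs):
    # args is already a hashable tuple; kwargs items are sorted so that
    # keyword order does not create separate cache entries.
    return (args, tuple(sorted(kwargs.items())))

k1 = make_key((512, 1024, 768), {"block_M": 128, "block_N": 128})
k2 = make_key((512, 1024, 768), {"block_N": 128, "block_M": 128})
assert k1 == k2  # keyword order is irrelevant: same cache entry

k3 = make_key((512, 1024, 768, 128), {"block_N": 128})
assert k1 != k3  # positional vs. keyword use of an argument yields a different key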