"driver/driver.hip.cpp" did not exist on "5e77650415119376712fbe7aefb2f3922af021db"
Commit 25a50f1a authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] refactor `tilelang.jit` to support a faster and more flexible kernel cache (#501)

* [Refactor] Update JIT kernel functions and streamline GEMM tests

* Renamed and refactored matmul and run_gemm functions to matmul_kernel_jit and run_gemm_kernel_jit for clarity.
* Removed redundant JIT decorator from the matmul function, ensuring it is applied only to the kernel function.
* Updated test function names to reflect changes in the kernel functions, enhancing consistency and readability.
* Cleaned up commented-out code and unnecessary imports to improve overall code quality.

* Update the main function call in the GEMM test to use the tilelang testing framework

* Update README and example scripts to include JIT decorator comments

* Added comments in README.md and various example scripts to indicate the use of the @tilelang.jit decorator for returning torch functions.
* Removed redundant comments that previously instructed to add the decorator, streamlining the documentation and improving clarity.

* Update GEMM test parameters for improved performance

* Set num_stages to 0 and adjusted matrix dimensions in test functions to enhance performance and consistency across GEMM tests in test_tilelang_kernel_gemm.py.
parent 33937683
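In short, the refactor moves @tilelang.jit from the inner @T.prim_func onto the Python factory that builds it; calling the decorated factory now returns a compiled kernel that is cached per argument list. A minimal sketch assembled from the diffs below (shapes, tile sizes, and the num_stages value are illustrative, not taken from any single test):

import tilelang
import tilelang.language as T
import torch

@tilelang.jit(out_idx=-1)  # create the output tensor at runtime
def matmul_kernel_jit(M, N, K, block_M, block_N, block_K,
                      dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

kernel = matmul_kernel_jit(512, 1024, 768, 128, 128, 32)  # compiled once, cached by argument key
A = torch.randn(512, 768, dtype=torch.float16, device="cuda")
B = torch.randn(768, 1024, dtype=torch.float16, device="cuda")
C = kernel(A, B)  # out_idx=-1: C is allocated by the adapter and returned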
@@ -126,8 +126,10 @@ import tilelang.language as T
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,)
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -4,7 +4,6 @@ from typing import Optional, Union
import torch
import triton
import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
@@ -13,6 +12,7 @@ from einops import rearrange
import tilelang
@tilelang.jit
def tilelang_kernel_fwd(
batch,
heads,
@@ -55,7 +55,6 @@ def tilelang_kernel_fwd(
num_stages = 0
threads = 32
@tilelang.jit
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
@@ -146,6 +145,7 @@ def tilelang_kernel_fwd(
return native_sparse_attention
@tilelang.jit
def tilelang_kernel_bwd_dkv(
batch,
heads,
@@ -193,7 +193,6 @@ def tilelang_kernel_bwd_dkv(
num_threads = 32
print("NV", NV, "NS", NS, "B", B, "H", H)
@tilelang.jit
@T.prim_func
def flash_bwd_dkv(
Q: T.Tensor(q_shape, dtype),
@@ -310,6 +309,7 @@ def make_dq_layout(dQ):
)
@tilelang.jit
def tilelang_kernel_bwd_dqkv(
batch,
heads,
@@ -357,7 +357,6 @@ def tilelang_kernel_bwd_dqkv(
block_mask_shape = [batch, seq_len, heads_kv, NS]
num_threads = 32
@tilelang.jit
@T.prim_func
def flash_bwd_dqkv(
Q: T.Tensor(q_shape, dtype),
@@ -473,6 +472,7 @@ def tilelang_kernel_bwd_dqkv(
return flash_bwd_dqkv
@tilelang.jit(out_idx=[2])
def tilelang_kernel_preprocess(
batch,
heads,
@@ -486,7 +486,6 @@ def tilelang_kernel_preprocess(
shape = [batch, seq_len, heads, dim]
@tilelang.jit(out_idx=[2], execution_backend="cython")
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
@@ -510,6 +509,7 @@ def tilelang_kernel_preprocess(
return flash_bwd_prep
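# (Hedged aside, not part of the diff: out_idx marks which prim_func parameters
# the adapter allocates at call time and returns as outputs. With out_idx=[2],
# the third parameter of flash_bwd_prep becomes the returned tensor, so the
# caller passes only the preceding arguments; names below are hypothetical:
#     delta = tilelang_kernel_preprocess(batch, heads, seq_len, dim)(O, dO))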
@tilelang.jit(out_idx=[2])
def tilelang_kernel_block_mask(
batch,
heads,
@@ -529,7 +529,6 @@ def tilelang_kernel_block_mask(
block_mask_shape = [batch, seq_len, heads, NS]
USE_BLOCK_COUNTS = block_counts is not None
@tilelang.jit(out_idx=[2], execution_backend="cython")
@T.prim_func
def flash_bwd_block_mask(
BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore
...
@@ -8,8 +8,10 @@ from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,) # noqa: F401
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,11 +2,13 @@ import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
num_stages = 2
mbarrier_list = [128, 128] * num_stages
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
...
@@ -2,6 +2,8 @@ import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul_warp_specialize_copy_0_gemm_1(M,
N,
K,
@@ -10,7 +12,7 @@ def matmul_warp_specialize_copy_0_gemm_1(M,
block_K,
dtype="float16",
accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,6 +2,8 @@ import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
@@ -10,7 +12,7 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
block_K,
dtype="float16",
accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -4,6 +4,8 @@ import tilelang.language as T
tilelang.disable_cache()
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
@@ -15,7 +17,7 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
warp_group_num = 2
threads = 128 * warp_group_num
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,8 +2,10 @@ import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
...
@@ -4,7 +4,10 @@ import tilelang
import torch
def matmul(
@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
)
def matmul_kernel_jit(
M,
N,
K,
@@ -26,9 +29,6 @@ def matmul(
import tilelang.language as T
@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -55,7 +55,7 @@ def matmul(
return main
def run_gemm(
def run_gemm_kernel_jit(
M,
N,
K,
@@ -70,7 +70,7 @@ def run_gemm(
num_stages=3,
num_threads=128,
):
matmul_kernel = matmul(
matmul_kernel = matmul_kernel_jit(
M,
N,
K,
@@ -106,126 +106,8 @@ def run_gemm(
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmu_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmu_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, out_idx=-1)
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
def test_gemm_f16f16f16_nn_kernel_jit():
run_gemm_kernel_jit(
512,
1024,
768,
...
@@ -62,7 +62,7 @@ def run_gemm(
block_M,
block_N,
block_K,
num_stages=3,
num_stages=0,
num_threads=128,
):
program = matmul(
@@ -109,9 +109,9 @@ def test_gemm_f16f16f16_nn():
"float16",
"float16",
128,
256,
128,
32,
2,
0,
)
@@ -174,9 +174,9 @@ def test_gemm_f16f16f16_tn():
"float16",
"float16",
128,
256,
128,
32,
2,
0,
)
@@ -191,9 +191,9 @@ def test_gemm_f16f16f16_nt():
"float16",
"float16",
128,
256,
128,
32,
2,
0,
)
@@ -401,9 +401,9 @@ def test_gemm_f16f16f16_sr():
"float16",
"float16",
128,
256,
128,
32,
2,
0,
)
...
@@ -2,8 +2,10 @@ import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -2,8 +2,10 @@ import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
...
@@ -3,8 +3,10 @@ import tilelang.language as T
import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
...
@@ -3,8 +3,10 @@ import tilelang.language as T
import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
@@ -44,8 +46,10 @@ def test_tilelang_copy_mask_parallel():
run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128)
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
@@ -84,8 +88,10 @@ def test_tilelang_copy_mask_copy():
run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128)
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
@@ -129,8 +135,10 @@ def test_tilelang_copy_mask_parallel_range():
run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128)
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
...
@@ -4,109 +4,32 @@ It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
from typing import Callable, List, Literal, Union, Any, Optional, Dict
from typing import (
Any,
List,
Union,
Callable,
Tuple,
TypeVar,
overload,
Literal,
Dict, # For type hinting dicts
Optional,
)
from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
from tilelang.jit.adapter import BaseKernelAdapter
from tilelang.jit.kernel import JITKernel
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.cache import cached
from os import path, makedirs
from logging import getLogger
import functools
logger = getLogger(__name__)
def jit(
func: Callable = None,
*, # Enforce keyword-only arguments from here on
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target: Union[str, Target] = "auto",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
) -> BaseKernelAdapter:
"""
A decorator (or decorator factory) that JIT-compiles a given TileLang PrimFunc
into a runnable kernel adapter using TVM. If called with arguments, it returns
a decorator that can be applied to a function. If called without arguments,
it directly compiles the given function.
Parameters
----------
func : Callable, optional
The TileLang PrimFunc to JIT-compile. If None, this function returns a
decorator that expects a TileLang PrimFunc.
out_idx : Union[List[int], int], optional
The index (or list of indices) of the function outputs. This can be used
to specify which outputs from the compiled function will be returned.
execution_backend : Literal["dlpack", "ctypes"], optional
The wrapper type to use for the kernel adapter. Currently, only "dlpack"
and "ctypes" are supported.
target : Union[str, Target], optional
The compilation target for TVM. If set to "auto", an appropriate target
will be inferred automatically. Otherwise, must be one of the supported
strings in AVALIABLE_TARGETS or a TVM Target instance.
Returns
-------
BaseKernelAdapter
An adapter object that encapsulates the compiled function and can be
used to execute it.
Raises
------
AssertionError
If the provided target is an invalid string not present in AVALIABLE_TARGETS.
"""
# If the target is specified as a string, ensure it is valid and convert to a TVM Target.
if isinstance(target, str):
assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
target = determine_target(target)
target = Target(target)
assert execution_backend in ["dlpack", "ctypes", "cython"], "Invalid execution backend."
def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter:
"""
Compile the given TileLang PrimFunc with TVM and build a kernel adapter.
Parameters
----------
tilelang_func : tvm.tir.PrimFunc
The TileLang (TVM TIR) function to compile.
Returns
-------
BaseKernelAdapter
The compiled and ready-to-run kernel adapter.
"""
if verbose:
logger.info(f"Compiling TileLang function:\n{tilelang_func}")
return compile(
tilelang_func,
target=target,
verbose=verbose,
execution_backend=execution_backend,
out_idx=out_idx,
pass_configs=pass_configs,
).adapter
# If `func` was given, compile it immediately and return the adapter.
if func is not None:
return _compile_and_create_adapter(func)
# Otherwise, return a decorator that expects a function to compile.
def real_decorator(tilelang_func: PrimFunc) -> BaseKernelAdapter:
return _compile_and_create_adapter(tilelang_func)
return real_decorator
def compile(
func: PrimFunc = None,
out_idx: Union[List[int], int, None] = None,
@@ -151,4 +74,219 @@ def compile(
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
)
\ No newline at end of file
)
# --- Mocking dependencies for the example to run ---
# In your actual code, these would be your real types.
class Program:
"""Placeholder for the type returned by the original decorated function."""
def __init__(self, data: str):
self.data = data
def __repr__(self):
return f"Program('{self.data}')"
class Kernel:
"""Placeholder for the type of the compiled kernel."""
def __init__(self, source: str, out_idx: Any):
self.source_code = source
self.out_idx = out_idx
def get_kernel_source(self) -> str:
return self.source_code
def __repr__(self):
return f"Kernel('{self.source_code[:20]}...')"
# --- End Mocking ---
# P (Parameters) captures the argument types of the decorated function.
_P = ParamSpec("_P")
# R_prog (Return type of Program) captures the return type of the original decorated function.
# We assume the original function returns something compatible with 'Program'.
_RProg = TypeVar("_RProg", bound=Program)
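# A hedged illustration (not in the commit) of what _P and _RProg buy:
# a decorator typed Callable[_P, _RProg] -> Callable[_P, Kernel] preserves the
# wrapped factory's parameter list for type checkers:
#
#     @jit(out_idx=-1)
#     def build(M: int, N: int) -> Program: ...
#
#     build(512, 1024)     # arguments checked against (M: int, N: int)
#     build("512", 1024)   # rejected by the type checker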
class jit:
# Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@overload
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[True]) -> None:
...
@overload
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[False] = False) -> None:
...
# Actual implementation of __init__
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: bool = False):
"""
Initializes the JIT compiler decorator.
Parameters
----------
out_idx : Any, optional
Index(es) of the output tensors to return from the compiled kernel
(default: None, meaning all outputs are returned or determined by the kernel itself).
target : Union[str, Target], optional
Compilation target for TVM. Can be a string (e.g., "cuda", "llvm")
or a TVM Target object. If "auto", the target is determined automatically
(default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation, similar to `target` (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
The backend used for kernel execution and argument passing.
"dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks.
"ctypes" uses standard C types. "cython" uses Cython for potentially faster execution.
(default: "cython").
verbose : bool, optional
If True, enables verbose logging during compilation (default: False).
pass_configs : Optional[Dict[str, Any]], optional
A dictionary of configurations for TVM's pass context. These can fine-tune
the compilation process. Examples include "tir.disable_vectorize"
(default: None).
debug_root_path : Optional[str], optional
If provided, the compiled kernel's source code will be saved to a file
in this directory. This is useful for debugging the generated code.
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
return_program : bool, optional
If True, the decorated function will return a tuple containing the
original program's result and the compiled kernel. If False, only the
compiled kernel is returned (default: False).
"""
if debug_root_path is None:
# No default debug directory is chosen: if the generated kernel source should
# be saved, the caller must supply a path explicitly.
pass
elif not path.isabs(debug_root_path): # If a relative path is given, make it absolute
try:
# This assumes the file is part of a typical project structure
base_path = path.dirname(path.dirname(path.dirname(__file__)))
debug_root_path = path.join(base_path, debug_root_path)
except NameError: # __file__ is not defined (e.g., in a REPL or notebook)
# Fallback to making it absolute based on current working directory if __file__ fails
debug_root_path = path.abspath(debug_root_path)
self.out_idx = out_idx
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
self.verbose = verbose
self.pass_configs = pass_configs
self.debug_root_path: Optional[str] = debug_root_path
self.return_program: bool = return_program
# Type hint the caches
self._program_cache: Dict[tuple, _RProg] = {}
self._kernel_cache: Dict[tuple, Kernel] = {}
# Overload __call__ based on the value of self.return_program
# This tells the type checker what the *wrapper* function will return.
# The wrapper will take the same parameters P as the original function.
# Case 1: return_program is True
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
# This signature is chosen by the type checker if self.return_program is True
# (inferred from the __init__ call).
...
# Case 2: return_program is False (or not specified, defaulting to False)
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
# This signature is chosen if self.return_program is False.
...
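# Hedged illustration of how the overloads resolve (hypothetical factory `build`):
#     jit(return_program=True)(build)  # typed Callable[..., Tuple[Program, Kernel]]
#     jit(out_idx=-1)(build)           # typed Callable[..., Kernel]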
# Actual implementation of __call__
def __call__(
self, func: Union[Callable[_P, _RProg], PrimFunc]
) -> Callable[_P, Any]: # Any for implementation flexibility
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Use _P.args and _P.kwargs
# Create a hashable key. args is already a tuple.
# For kwargs, convert to a sorted tuple of items to ensure consistent ordering.
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
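# For example (illustrative values): wrapper(512, 1024, dtype="float16") gives
# key = ((512, 1024), (("dtype", "float16"),)), so a later call with identical
# positional and keyword arguments is served from the cache.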
# The program and its kernel are always cached together, so a miss in the
# program cache means both must be (re)built; checking one cache suffices.
if key not in self._program_cache:
if isinstance(func, PrimFunc):
program_result = func
elif isinstance(func, Callable):
program_result = func(*args, **kwargs)
else:
raise ValueError(f"Invalid function type: {type(func)}")
kernel_result = compile(
program_result,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
)
if self.debug_root_path: # Check if a path is provided
func_name = func.__name__
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
# Ensure the debug directory exists
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
self._program_cache[key] = program_result
self._kernel_cache[key] = kernel_result
# Retrieve from cache (even if just populated)
cached_program = self._program_cache[key]
cached_kernel = self._kernel_cache[key]
if self.return_program:
return cached_program, cached_kernel
else:
return cached_kernel
return wrapper
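Putting it together, a hedged usage sketch of the class-based decorator (factory bodies elided; shapes illustrative):

@jit(out_idx=-1)
def build_matmul(M: int, N: int, K: int) -> Program:
    ...  # construct and return the prim_func, as in the GEMM examples above

kernel = build_matmul(512, 1024, 768)        # first call: build program, compile, cache
kernel_again = build_matmul(512, 1024, 768)  # same key, cached Kernel returned

@jit(return_program=True)
def build_matmul_debug(M: int, N: int, K: int) -> Program:
    ...

program, kernel = build_matmul_debug(512, 1024, 768)  # tuple when return_program=True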