""" This module provides an auto-tuning infrastructure for TileLang (tl) programs. It includes functionality to JIT-compile TileLang programs into a runnable kernel adapter using TVM. """ from typing import ( Any, List, Union, Callable, Tuple, overload, Literal, Dict, # For type hinting dicts Optional, ) from tilelang import tvm as tvm from tvm.tir import PrimFunc from tvm.target import Target from tilelang.jit.kernel import JITKernel from tilelang.utils.target import determine_target from tilelang.cache import cached from os import path, makedirs from logging import getLogger import functools from tilelang.jit.param import Kernel, _P, _RProg logger = getLogger(__name__) def compile( func: PrimFunc = None, out_idx: Union[List[int], int, None] = None, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", target: Union[str, Target] = "auto", target_host: Union[str, Target, None] = None, verbose: bool = False, pass_configs: Optional[Dict[str, Any]] = None, compile_flags: Optional[Union[List[str], str]] = None, ) -> JITKernel: """ Compile the given TileLang PrimFunc with TVM and build a JITKernel. Parameters ---------- func : tvm.tir.PrimFunc, optional The TileLang TIR function to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional Execution backend to use for kernel execution (default: "cython"). target : Union[str, Target], optional Compilation target, either as a string or a TVM Target object (default: "auto"). target_host : Union[str, Target], optional Target host for cross-compilation (default: None). verbose : bool, optional Whether to enable verbose output (default: False). pass_configs : dict, optional Additional keyword arguments to pass to the Compiler PassContext. Available options: "tir.disable_vectorize": bool, default: False "tl.disable_tma_lower": bool, default: False "tl.disable_warp_specialized": bool, default: False "tl.config_index_bitwidth": int, default: None "tl.disable_dynamic_tail_split": bool, default: False "tl.dynamic_vectorize_size_bits": int, default: 128 "tl.disable_safe_memory_legalize": bool, default: False """ assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}" if isinstance(compile_flags, str): compile_flags = [compile_flags] # This path is not a performance critical path, so we can afford to convert the target. target = Target(determine_target(target)) return cached( func=func, out_idx=out_idx, execution_backend=execution_backend, target=target, target_host=target_host, verbose=verbose, pass_configs=pass_configs, compile_flags=compile_flags, ) class _JitImplementation: out_idx: Optional[Union[List[int], int]] target: Union[str, Target] target_host: Union[str, Target] execution_backend: Literal["dlpack", "ctypes", "cython"] verbose: bool pass_configs: Optional[Dict[str, Any]] debug_root_path: Optional[str] compile_flags: Optional[Union[List[str], str]] def __init__(self, out_idx: Any = None, target: Union[str, Target] = "auto", target_host: Union[str, Target] = None, execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", verbose: bool = False, pass_configs: Optional[Dict[str, Any]] = None, debug_root_path: Optional[str] = None, compile_flags: Optional[Union[List[str], str]] = None): """ Initializes the JIT compiler decorator. Parameters ---------- out_idx : Any, optional Index(es) of the output tensors to return from the compiled kernel (default: None, meaning all outputs are returned or determined by the kernel itself). target : Union[str, Target], optional Compilation target for TVM. Can be a string (e.g., "cuda", "llvm") or a TVM Target object. If "auto", the target is determined automatically (default: "auto"). target_host : Union[str, Target], optional Target host for cross-compilation, similar to `target` (default: None). execution_backend : Literal["dlpack", "ctypes", "cython"], optional The backend used for kernel execution and argument passing. "dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks. "ctypes" uses standard C types. "cython" uses Cython for potentially faster execution. (default: "cython"). verbose : bool, optional If True, enables verbose logging during compilation (default: False). pass_configs : Optional[Dict[str, Any]], optional A dictionary of configurations for TVM's pass context. These can fine-tune the compilation process. Examples include "tir.disable_vectorize" (default: None). debug_root_path : Optional[str], optional If provided, the compiled kernel's source code will be saved to a file in this directory. This is useful for debugging the generated code. If None, no debug information is saved (default: None). If a relative path is given, it's made absolute relative to the project root or current working directory. compile_flags : Optional[Union[List[str], str]], optional Additional compilation flags to pass to the compiler. If None, no additional compilation flags are passed (default: None). """ self.out_idx = out_idx self.execution_backend = execution_backend self.target = target self.target_host = target_host self.verbose = verbose self.pass_configs = pass_configs self.compile_flags = compile_flags # Corrected debug_root_path handling self.debug_root_path = debug_root_path if self.debug_root_path is not None and not path.isabs(self.debug_root_path): try: base_path = path.dirname(path.dirname(path.dirname(__file__))) self.debug_root_path = path.join(base_path, self.debug_root_path) except NameError: self.debug_root_path = path.abspath(self.debug_root_path) self._kernel_cache: Dict[tuple, Kernel] = {} # This tells the type checker what the *wrapper* function will return. # this is for linting, please do not remove it. @overload def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]: ... @overload def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]: ... # Actual implementation of __call__ def __call__( self, func: Callable[_P, _RProg] # func is Union[Callable[_P, _RProg], PrimFunc] in original ) -> Callable[_P, Any]: @functools.wraps(func) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Separate out the tuning parameters from the user's kwargs tune_params = kwargs.pop('__tune_params', {}) # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache return_compile_arguments = kwargs.pop('__return_compile_arguments', False) if return_compile_arguments: compile_args = { 'out_idx': self.out_idx, 'execution_backend': self.execution_backend, 'target': self.target, 'target_host': self.target_host, 'verbose': self.verbose, 'pass_configs': self.pass_configs, 'compile_flags': self.compile_flags, } return compile_args key_args_tuple = args key_kwargs_tuple = tuple(sorted(kwargs.items())) tuned_key_kwargs_tuple = tuple(sorted(tune_params.items())) key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) if key not in self._kernel_cache: # Ensure 'func' (the original user function) is used correctly program_result_source = func if isinstance(program_result_source, PrimFunc): program_result = program_result_source elif callable(program_result_source): program_result = program_result_source(*args, **kwargs, **tune_params) else: raise ValueError(f"Invalid function type: {type(program_result_source)}") kernel_result = compile( program_result, out_idx=self.out_idx, execution_backend=self.execution_backend, target=self.target, target_host=self.target_host, verbose=self.verbose, pass_configs=self.pass_configs, compile_flags=self.compile_flags, ) if self.debug_root_path: func_name = getattr(func, '__name__', 'jit_kernel') # Use func for name kernel_file = f'tilelang_jit_kernel_{func_name}.c' program_file = f'tilelang_jit_program_{func_name}.py' makedirs(self.debug_root_path, exist_ok=True) with open(path.join(self.debug_root_path, kernel_file), 'w') as f: print(kernel_result.get_kernel_source(), file=f) with open(path.join(self.debug_root_path, program_file), 'w') as f: print(program_result.script(), file=f) self._kernel_cache[key] = kernel_result return self._kernel_cache[key] return wrapper def jit( # This is the new public interface func: Union[Callable[_P, _RProg], PrimFunc, None] = None, *, # Indicates subsequent arguments are keyword-only out_idx: Any = None, target: Union[str, Target] = "auto", target_host: Union[str, Target] = None, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", verbose: bool = False, pass_configs: Optional[Dict[str, Any]] = None, debug_root_path: Optional[str] = None, compile_flags: Optional[Union[List[str], str]] = None): """ Just-In-Time (JIT) compiler decorator for TileLang functions. This decorator can be used without arguments (e.g., `@tilelang.jit`): Applies JIT compilation with default settings. Parameters ---------- func_or_out_idx : Any, optional If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. If using `@tilelang.jit` directly on a function, this argument is implicitly the function to be decorated (and `out_idx` will be `None`). target : Union[str, Target], optional Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". target_host : Union[str, Target], optional Target host for cross-compilation. Defaults to None. execution_backend : Literal["dlpack", "ctypes", "cython"], optional Backend for kernel execution and argument passing. Defaults to "cython". verbose : bool, optional Enables verbose logging during compilation. Defaults to False. pass_configs : Optional[Dict[str, Any]], optional Configurations for TVM's pass context. Defaults to None. debug_root_path : Optional[str], optional Directory to save compiled kernel source for debugging. Defaults to None. Returns ------- Callable Either a JIT-compiled wrapper around the input function, or a configured decorator instance that can then be applied to a function. """ if isinstance(compile_flags, str): compile_flags = [compile_flags] if callable(func): # Case 1: Used as @jit (func_or_out_idx is the function, others are defaults) # Create a default _JitImplementation instance and apply it to the function. default_decorator = _JitImplementation( out_idx=out_idx, # Explicitly None for the default case target=target, target_host=target_host, execution_backend=execution_backend, verbose=verbose, pass_configs=pass_configs, debug_root_path=debug_root_path, compile_flags=compile_flags) return default_decorator(func) elif isinstance(func, PrimFunc): raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") else: # Case 2: Used as @jit(...) to configure, or func_or_out_idx is meant as out_idx. # Create a _JitImplementation instance with the provided/defaulted arguments. # This instance is a decorator that will be applied to the function later. configured_decorator = _JitImplementation( out_idx=out_idx, # Pass along; could be an actual out_idx or None target=target, target_host=target_host, execution_backend=execution_backend, verbose=verbose, pass_configs=pass_configs, debug_root_path=debug_root_path, compile_flags=compile_flags) return configured_decorator