Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
@@ -3,6 +3,7 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs.
 It includes functionality to JIT-compile TileLang programs into a runnable
 kernel adapter using TVM.
 """
 from __future__ import annotations
 from dataclasses import dataclass
@@ -39,17 +40,16 @@ from tqdm.auto import tqdm
 logger = getLogger(__name__)
-_P = ParamSpec('_P')
-_KP = ParamSpec('_KP')
-_T = TypeVar('_T')
-_Ret = TypeVar('_Ret')
+_P = ParamSpec("_P")
+_KP = ParamSpec("_KP")
+_T = TypeVar("_T")
+_Ret = TypeVar("_Ret")
 def compile(
     func: PrimFunc[_KP, _T] = None,
     out_idx: list[int] | int | None = None,
-    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
-                               "torch"] = "auto",
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
     target: str | Target = "auto",
     target_host: str | Target | None = None,
     verbose: bool = False,
@@ -83,11 +83,9 @@ def compile(
     if isinstance(compile_flags, str):
         compile_flags = [compile_flags]
-    if hasattr(func, 'out_idx_override'):
+    if hasattr(func, "out_idx_override"):
         if func.out_idx_override is not None and out_idx is not None:
-            raise ValueError(
-                "Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors"
-            )
+            raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors")
         out_idx = func.out_idx_override or out_idx
     # This path is not a performance critical path, so we can afford to convert the target.
@@ -96,6 +94,7 @@ def compile(
     # Resolve execution backend (handles aliases, auto, validation per target)
     requested_backend = execution_backend
     from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
     execution_backend = resolve_execution_backend(requested_backend, target)
     if verbose:
         allowed_now = allowed_backends_for_target(target, include_unavailable=False)
@@ -119,17 +118,18 @@ def compile(
     )
-def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
-                out_idx: list[int] | int | None = None,
-                execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
-                                           "torch"] = "auto",
-                target: str | Target = "auto",
-                target_host: str | Target | None = None,
-                verbose: bool = False,
-                pass_configs: dict[str, Any] | None = None,
-                compile_flags: list[str] | str | None = None,
-                num_workers: int = None,
-                ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
+def par_compile(
+    funcs: Iterable[PrimFunc[_KP, _T]],
+    out_idx: list[int] | int | None = None,
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
+    target: str | Target = "auto",
+    target_host: str | Target | None = None,
+    verbose: bool = False,
+    pass_configs: dict[str, Any] | None = None,
+    compile_flags: list[str] | str | None = None,
+    num_workers: int = None,
+    ignore_error: bool = False,
+) -> list[JITKernel[_KP, _T]]:
     """
     Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
     Parameters
@@ -151,7 +151,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
         Additional keyword arguments to pass to the Compiler PassContext.
         Refer to `tilelang.transform.PassConfigKey` for supported options.
     """
-    with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
+    with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor:
         futures = []
         future_map = {}
         for i, func in enumerate(funcs):
@@ -170,9 +170,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
             futures.append(future)
     results = [... for _ in futures]
     for future in tqdm(
         concurrent.futures.as_completed(futures),
         total=len(futures),
         desc="Parallel Compiling",
     ):
         idx = future_map[future]
         if ignore_error:
@@ -189,7 +189,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
 @dataclass
 class JITImpl(Generic[_P, _KP, _T, _Ret]):
-    '''
+    """
     Detailed Just-In-Time wrapper for TileLang programs.
     This dataclass encapsulates the configuration and runtime helpers used by the
@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         PrimFunc and the resulting set is compiled in parallel via the
         module-level `par_compile` helper. Returns a list of JITKernel objects
         in the same order as the provided configs.
-    '''
+    """
     out_idx: list[int] | int | None
     execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
@@ -302,10 +302,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
         return tir
-    def par_compile(self,
-                    configs: Iterable[dict[str, Any] | tuple[str, Any]],
-                    num_workers: int = None,
-                    ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
+    def par_compile(
+        self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False
+    ) -> list[JITKernel[_KP, _T]]:
         """
         Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
         Parameters
@@ -328,7 +327,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         """
         configs = list(configs)
         funcs = []
-        for cfg in tqdm(configs, desc='Elaborating'):
+        for cfg in tqdm(configs, desc="Elaborating"):
             if isinstance(cfg, tuple):
                 funcs.append(self.get_tir(*cfg))
             elif isinstance(cfg, dict):
@@ -345,7 +344,8 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
             pass_configs=self.pass_configs,
             compile_flags=self.compile_flags,
             num_workers=num_workers,
-            ignore_error=ignore_error)
+            ignore_error=ignore_error,
+        )
     def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
         func = self.get_tir(*args, **kwargs)
@@ -362,25 +362,25 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         if self.debug_root_path:
             if isinstance(self.func, PrimFunc):
-                func_name = self.func.attrs['global_symbol']
+                func_name = self.func.attrs["global_symbol"]
             else:
-                func_name = getattr(self.func, '__name__', 'jit_kernel')
-            kernel_file = f'tilelang_jit_kernel_{func_name}.c'
-            program_file = f'tilelang_jit_program_{func_name}.py'
+                func_name = getattr(self.func, "__name__", "jit_kernel")
+            kernel_file = f"tilelang_jit_kernel_{func_name}.c"
+            program_file = f"tilelang_jit_program_{func_name}.py"
             makedirs(self.debug_root_path, exist_ok=True)
-            with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
+            with open(path.join(self.debug_root_path, kernel_file), "w") as f:
                 print(kernel_result.get_kernel_source(), file=f)
-            with open(path.join(self.debug_root_path, program_file), 'w') as f:
+            with open(path.join(self.debug_root_path, program_file), "w") as f:
                 print(func.script(), file=f)
         return kernel_result
     def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs):
         if isinstance(self.func, PrimFuncCreater):
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             return self.func.func_annot.parse_key(*args, **kwargs, **tune_params)
         else:
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             key_args_tuple = args
             key_kwargs_tuple = tuple(sorted(kwargs.items()))
             tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
@@ -389,34 +389,31 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
     def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs):
         if isinstance(self.func, PrimFuncCreater):
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
         else:
-            raise NotImplementedError(
-                "convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")
+            raise NotImplementedError("convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")
     def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
         # Separate out the tuning parameters from the user's kwargs
         # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
-        return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
+        return_compile_arguments = kwargs.pop("__return_compile_arguments", False)
         if return_compile_arguments:
-            logger.warning(
-                "`__return_compile_arguments` is deprecated and will be removed in future versions."
-            )
+            logger.warning("`__return_compile_arguments` is deprecated and will be removed in future versions.")
             compile_args = {
-                'out_idx': self.out_idx,
-                'execution_backend': self.execution_backend,
-                'target': self.target,
-                'target_host': self.target_host,
-                'verbose': self.verbose,
-                'pass_configs': self.pass_configs,
-                'compile_flags': self.compile_flags,
+                "out_idx": self.out_idx,
+                "execution_backend": self.execution_backend,
+                "target": self.target,
+                "target_host": self.target_host,
+                "verbose": self.verbose,
+                "pass_configs": self.pass_configs,
+                "compile_flags": self.compile_flags,
             }
             return compile_args
         key = self.parse_cache_key(*args, **kwargs)
-        tune_params = kwargs.pop('__tune_params', {})
+        tune_params = kwargs.pop("__tune_params", {})
         kernel = self._kernel_cache.get(key, None)
         if kernel is None:
@@ -434,8 +431,7 @@ ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvr
 @overload
-def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]:
-    ...
+def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ...
 @overload
@@ -448,22 +444,22 @@ def jit(
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None
-) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]:
-    ...
+    compile_flags: list[str] | str | None = None,
+) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ...
 def jit(  # This is the new public interface
     func: Callable[_P, _T] | PrimFunc | None = None,
     *,  # Indicates subsequent arguments are keyword-only
     out_idx: Any = None,
     target: str | Target = "auto",
     target_host: str | Target = None,
     execution_backend: ExecutionBackend = "auto",
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None):
+    compile_flags: list[str] | str | None = None,
+):
     """
     Just-In-Time (JIT) compiler decorator for TileLang functions.
@@ -516,7 +512,8 @@ def jit(  # This is the new public interface
             compile_flags=compile_flags,
             func_source=inspect.getsource(orig_func),
             signature=inspect.signature(orig_func),
-            lazy_jit=False)
+            lazy_jit=False,
+        )
     if func is not None:
         return decorator(func)
@@ -525,8 +522,7 @@ def jit(  # This is the new public interface
 @overload
-def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]:
-    ...
+def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ...
 @overload
@@ -539,9 +535,8 @@ def lazy_jit(
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None
-) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]:
-    ...
+    compile_flags: list[str] | str | None = None,
+) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ...
 def lazy_jit(
@@ -555,7 +550,6 @@ def lazy_jit(
     debug_root_path: str | None = None,
     compile_flags: list[str] | str | None = None,
 ):
     if isinstance(compile_flags, str):
         compile_flags = [compile_flags]
@@ -567,7 +561,8 @@ def lazy_jit(
         verbose=verbose,
         pass_configs=pass_configs,
         debug_root_path=debug_root_path,
-        compile_flags=compile_flags)
+        compile_flags=compile_flags,
+    )
     def decorator(func: Callable[_P, _T]):
         pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True)
@@ -576,10 +571,7 @@ def lazy_jit(
         # return compile(pf, **compile_args)
         # else:
         return JITImpl(
-            func=pf,
-            **compile_args,
-            func_source=inspect.getsource(pf.orig_func),
-            signature=inspect.signature(pf.orig_func),
-            lazy_jit=True)
+            func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True
+        )
     return decorator(func) if func is not None else decorator
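
For reference, the `parse_cache_key` hunk above builds the kernel-cache key from the positional arguments, the sorted keyword arguments, and the sorted `__tune_params`. A minimal standalone sketch of that generic-path logic, assuming nothing beyond what the hunk shows (`make_cache_key` is an illustrative name, not part of the codebase):

```python
def make_cache_key(*args, **kwargs):
    # Tuning parameters travel in the reserved "__tune_params" kwarg, as in the hunk above.
    tune_params = kwargs.pop("__tune_params", {})
    key_args_tuple = args
    key_kwargs_tuple = tuple(sorted(kwargs.items()))
    tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
    # Sorting makes the key independent of keyword order; tuples of hashables can be dict keys.
    return (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)


print(make_cache_key(1024, 1024, dtype="float16", __tune_params={"block_M": 128}))
# ((1024, 1024), (('dtype', 'float16'),), (('block_M', 128),))
```

Because the key is a plain tuple of hashable values, looking up a previously compiled kernel in `_kernel_cache` is an ordinary dict access.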
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -8,7 +9,6 @@ import torch ...@@ -8,7 +9,6 @@ import torch
class BaseKernelAdapter(ABC): class BaseKernelAdapter(ABC):
func: Callable | None = None func: Callable | None = None
def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None: def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None:
...@@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC): ...@@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC):
result_idx = [] result_idx = []
elif isinstance(result_idx, int): elif isinstance(result_idx, int):
if result_idx > len(params) or result_idx < -len(params): if result_idx > len(params) or result_idx < -len(params):
raise ValueError( raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
)
if result_idx < 0: if result_idx < 0:
result_idx = len(params) + result_idx result_idx = len(params) + result_idx
result_idx = [result_idx] result_idx = [result_idx]
elif isinstance(result_idx, list): elif isinstance(result_idx, list):
for i, idx in enumerate(result_idx): for i, idx in enumerate(result_idx):
if idx >= len(params) or idx < -len(params): if idx >= len(params) or idx < -len(params):
raise ValueError( raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
)
if idx < 0: if idx < 0:
result_idx[i] = len(params) + idx result_idx[i] = len(params) + idx
else: else:
......
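
The hunk above normalizes `result_idx` (used elsewhere in this diff via `_legalize_result_idx`) into a list of non-negative positions. A self-contained sketch of that behavior, assuming the elided leading branch guards `result_idx is None` (the standalone function name is hypothetical):

```python
def legalize_result_idx(result_idx, params):
    # Normalize result_idx into a list of non-negative indices into `params`,
    # mirroring the validation shown in the hunk above.
    if result_idx is None:  # assumed: this branch is elided in the hunk
        result_idx = []
    elif isinstance(result_idx, int):
        if result_idx > len(params) or result_idx < -len(params):
            raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
        if result_idx < 0:
            result_idx = len(params) + result_idx
        result_idx = [result_idx]
    elif isinstance(result_idx, list):
        for i, idx in enumerate(result_idx):
            if idx >= len(params) or idx < -len(params):
                raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
            if idx < 0:
                result_idx[i] = len(params) + idx
    return result_idx


print(legalize_result_idx(-1, params=["A", "B", "C"]))  # [2]
```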
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations from __future__ import annotations
import torch import torch
from ..base import BaseKernelAdapter from ..base import BaseKernelAdapter
...@@ -41,18 +42,20 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -41,18 +42,20 @@ class CtypesKernelAdapter(BaseKernelAdapter):
param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes
param_shapes: list[list] | None = None # Cache for parameter shapes param_shapes: list[list] | None = None # Cache for parameter shapes
def __init__(self, def __init__(
params: list[TensorType], self,
result_idx: list[int], params: list[TensorType],
target: str, result_idx: list[int],
func_or_mod: tir.PrimFunc | tvm.IRModule, target: str,
host_mod: tvm.IRModule | None = None, func_or_mod: tir.PrimFunc | tvm.IRModule,
device_mod: tvm.IRModule | None = None, host_mod: tvm.IRModule | None = None,
host_kernel_source: str | None = None, device_mod: tvm.IRModule | None = None,
device_kernel_source: str | None = None, host_kernel_source: str | None = None,
verbose: bool = False, device_kernel_source: str | None = None,
pass_configs: dict[str, Any] | None = None, verbose: bool = False,
compile_flags: list[str] | None = None): pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
"""Initialize the adapter with the given TIR function or module. """Initialize the adapter with the given TIR function or module.
Args: Args:
...@@ -109,17 +112,19 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -109,17 +112,19 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self._post_init() self._post_init()
@classmethod @classmethod
def from_database(cls, def from_database(
params: list[TensorType], cls,
result_idx: list[int], params: list[TensorType],
target: str, result_idx: list[int],
func_or_mod: tir.PrimFunc | tvm.IRModule, target: str,
host_kernel_source: str, func_or_mod: tir.PrimFunc | tvm.IRModule,
device_kernel_source: str, host_kernel_source: str,
kernel_lib_path: str, device_kernel_source: str,
verbose: bool = False, kernel_lib_path: str,
pass_configs: dict[str, Any] | None = None, verbose: bool = False,
compile_flags: list[str] | None = None): pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
...@@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter):
if param in buffer_map: if param in buffer_map:
buffer = buffer_map[param] buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape): for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
(shape not in params)):
dynamic_symbolic_map[shape] = (0, i, j) dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params): for i, param in enumerate(params):
if param in buffer_map: if param in buffer_map:
buffer = buffer_map[param] buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides): for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j) dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
...@@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
Converts PyTorch tensor pointers to C void pointers for ctypes interface. Converts PyTorch tensor pointers to C void pointers for ctypes interface.
""" """
ctypes_args = [ ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args
]
ctypes_args.append(ctypes.c_void_p(stream)) ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args) self.lib.call(*ctypes_args)
...@@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
@property @property
def is_dynamic(self): def is_dynamic(self):
"""Indicates whether the kernel handles dynamic shapes.""" """Indicates whether the kernel handles dynamic shapes."""
return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0) return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0
def get_kernel_source(self, kernel_only: bool = False): def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel.""" """Returns the source code of the compiled kernel."""
......
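
The ctypes forward path above marshals kernel arguments by passing each tensor's `data_ptr()` as a `void*`, forwarding plain ints unchanged, and appending the CUDA stream handle last. A hedged sketch of just that conversion step; `to_ctypes_args` and `_FakeTensor` are illustrative stand-ins, not names from the repository:

```python
import ctypes


class _FakeTensor:
    # Stand-in for a torch.Tensor so the sketch runs anywhere; only data_ptr() is needed here.
    def data_ptr(self) -> int:
        return 0x1000


def to_ctypes_args(args, stream: int = 0):
    # Tensors become raw void* pointers, ints pass through, and the stream handle goes last,
    # mirroring the list comprehension in the hunk above.
    ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
    ctypes_args.append(ctypes.c_void_p(stream))
    return ctypes_args


print(to_ctypes_args([_FakeTensor(), 42]))
```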
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations from __future__ import annotations
import ctypes import ctypes
import logging import logging
...@@ -70,17 +71,19 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -70,17 +71,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
# Pass configs for the compiler # Pass configs for the compiler
pass_configs: dict[str, Any] | None = None pass_configs: dict[str, Any] | None = None
def __init__(self, def __init__(
params: list[KernelParam], self,
result_idx: list[int], params: list[KernelParam],
target: str | Target, result_idx: list[int],
func_or_mod: tir.PrimFunc | tvm.IRModule, target: str | Target,
host_mod: tvm.IRModule | None = None, func_or_mod: tir.PrimFunc | tvm.IRModule,
device_mod: tvm.IRModule | None = None, host_mod: tvm.IRModule | None = None,
device_kernel_source: str | None = None, device_mod: tvm.IRModule | None = None,
verbose: bool = False, device_kernel_source: str | None = None,
pass_configs: dict[str, Any] | None = None, verbose: bool = False,
compile_flags: list[str] | None = None): pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
"""Initialize the adapter with the given TIR function or module. """Initialize the adapter with the given TIR function or module.
Args: Args:
...@@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.lib.get_last_error.restype = ctypes.c_char_p self.lib.get_last_error.restype = ctypes.c_char_p
result = self.lib.init() result = self.lib.init()
if result != 0: if result != 0:
error_msg = self.lib.get_last_error().decode('utf-8') error_msg = self.lib.get_last_error().decode("utf-8")
error_msg += f"\n{self.lib_code}" error_msg += f"\n{self.lib_code}"
raise RuntimeError(f"Initialization failed: {error_msg}") raise RuntimeError(f"Initialization failed: {error_msg}")
...@@ -145,17 +148,19 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -145,17 +148,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
self._post_init() self._post_init()
@classmethod @classmethod
def from_database(cls, def from_database(
params: list[TensorType], cls,
result_idx: list[int], params: list[TensorType],
target: str, result_idx: list[int],
func_or_mod: tir.PrimFunc | tvm.IRModule, target: str,
host_kernel_source: str, func_or_mod: tir.PrimFunc | tvm.IRModule,
device_kernel_source: str, host_kernel_source: str,
kernel_lib_path: str, device_kernel_source: str,
verbose: bool = False, kernel_lib_path: str,
pass_configs: dict[str, Any] | None = None, verbose: bool = False,
compile_flags: list[str] | None = None): pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
...@@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.lib.get_last_error.restype = ctypes.c_char_p adapter.lib.get_last_error.restype = ctypes.c_char_p
result = adapter.lib.init() result = adapter.lib.init()
if result != 0: if result != 0:
error_msg = adapter.lib.get_last_error().decode('utf-8') error_msg = adapter.lib.get_last_error().decode("utf-8")
raise RuntimeError(f"Initialization failed: {error_msg}") raise RuntimeError(f"Initialization failed: {error_msg}")
adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, adapter.lib)
adapter.lib)
adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map) adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map)
adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map) adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map) adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
...@@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
if param in buffer_map: if param in buffer_map:
buffer = buffer_map[param] buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape): for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
(shape not in params)):
dynamic_symbolic_map[shape] = (0, i, j) dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params): for i, param in enumerate(params):
if param in buffer_map: if param in buffer_map:
buffer = buffer_map[param] buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides): for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j) dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
...@@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
params = func.params params = func.params
ptr_map = {} ptr_map = {}
for i, param in enumerate(params): for i, param in enumerate(params):
if param.dtype == 'handle': if param.dtype == "handle":
ptr_map[i] = param.name ptr_map[i] = param.name
return ptr_map return ptr_map
def _process_static_buffer_infos(self) -> \ def _process_static_buffer_infos(
tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], self,
dict[tir.Var, tuple[int, list[tuple[int, int]]]], ) -> tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]], list[tuple[tir.Var]]]:
list[tuple[tir.Var]]]:
"""Extract information about static shapes from the TIR function. """Extract information about static shapes from the TIR function.
Maps buffer variables to their corresponding static shapes. Maps buffer variables to their corresponding static shapes.
...@@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
Converts PyTorch tensor pointers to C void pointers for ctypes interface. Converts PyTorch tensor pointers to C void pointers for ctypes interface.
""" """
ctypes_args = [ ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args
]
ctypes_args.append(ctypes.c_void_p(stream)) ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args) self.lib.call(*ctypes_args)
...@@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
skip_tensor_validation: Whether to skip tensor attributes validation which skip_tensor_validation: Whether to skip tensor attributes validation which
includes shape, dtype, device, etc. includes shape, dtype, device, etc.
""" """
return self.cython_wrapper.forward([*args], return self.cython_wrapper.forward([*args], stream=stream, skip_tensor_validation=skip_tensor_validation)
stream=stream,
skip_tensor_validation=skip_tensor_validation)
return lambda_forward return lambda_forward
......
@@ -55,6 +55,7 @@ class LibraryGenerator:
         verbose = self.verbose
         if is_cuda_target(target):
             from tilelang.env import CUTLASS_INCLUDE_DIR
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)  # noqa: SIM115
             target_arch = get_target_arch(get_target_compute_version(target))
             libpath = src.name.replace(".cu", ".so")
@@ -65,15 +66,12 @@ class LibraryGenerator:
                     "TL_ENABLE_FAST_MATH",
                     "0.1.7",
                 )
-                enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH,
-                                                             True)
+                enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True)
             else:
                 enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)
-            ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL,
-                                                      None)
-            verbose_ptxas_output = self.pass_configs.get(
-                PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
+            ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None)
+            verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
             command = [
                 get_nvcc_compiler(),
@@ -102,6 +100,7 @@ class LibraryGenerator:
         elif is_hip_target(target):
             from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)  # noqa: SIM115
             libpath = src.name.replace(".cpp", ".so")
             rocm_path = find_rocm_path()
@@ -119,6 +118,7 @@ class LibraryGenerator:
             ]
         elif is_cpu_target(target):
             from tilelang.contrib.cc import get_cplus_compiler
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)  # noqa: SIM115
             libpath = src.name.replace(".cpp", ".so")
@@ -134,9 +134,7 @@ class LibraryGenerator:
         ]
         if self.compile_flags:
-            command += [
-                item for flag in self.compile_flags for item in flag.split() if item not in command
-            ]
+            command += [item for flag in self.compile_flags for item in flag.split() if item not in command]
         command += ["-o", libpath]
@@ -151,8 +149,7 @@ class LibraryGenerator:
             raise RuntimeError(f"Compile kernel failed because of {e}") from e
         if ret.returncode != 0:
-            raise RuntimeError(f"Compilation Failed! {command}"
-                               f"\n {self.lib_code}")
+            raise RuntimeError(f"Compilation Failed! {command}\n {self.lib_code}")
         self.srcpath = src.name
         self.libpath = libpath
......
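
The compile-command hunk above merges user-supplied `compile_flags` into the compiler command while skipping tokens that are already present. A small sketch of that merge rule (the `merge_compile_flags` helper name is hypothetical):

```python
def merge_compile_flags(command, compile_flags):
    # Each flag string is split on whitespace and a token is appended only when it is
    # not already part of the command, mirroring the list comprehension in the hunk above.
    command = list(command)
    command += [item for flag in compile_flags for item in flag.split() if item not in command]
    return command


print(merge_compile_flags(["nvcc", "-O3"], ["-O3 -lineinfo", "--use_fast_math"]))
# ['nvcc', '-O3', '-lineinfo', '--use_fast_math']
```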
@@ -5,22 +5,22 @@ This module provides runtime compilation support using NVIDIA's NVRTC API.
 import logging
-__all__ = [
-    'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available',
-    'check_nvrtc_available'
-]
+__all__ = ["NVRTCKernelAdapter", "TLNVRTCSourceWrapper", "NVRTCLibraryGenerator", "is_nvrtc_available", "check_nvrtc_available"]
 logger = logging.getLogger(__name__)
 # Check if cuda-python is available
 is_nvrtc_available = False
-NVRTC_UNAVAILABLE_MESSAGE = ("cuda-python is not available, NVRTC backend cannot be used. "
-                             "Please install cuda-python via `pip install cuda-python` "
-                             "if you want to use the NVRTC backend.")
+NVRTC_UNAVAILABLE_MESSAGE = (
+    "cuda-python is not available, NVRTC backend cannot be used. "
+    "Please install cuda-python via `pip install cuda-python` "
+    "if you want to use the NVRTC backend."
+)
 try:
     import cuda.bindings.driver as cuda  # noqa: F401
     import cuda.bindings.nvrtc as nvrtc  # noqa: F401
     is_nvrtc_available = True
 except ImportError as e:
     logger.debug(f"cuda-python import failed: {e}")
......
@@ -27,18 +27,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
     pymodule = None
     kernels = {}
-    def __init__(self,
-                 params: list[KernelParam],
-                 result_idx: list[int],
-                 target: str | Target,
-                 func_or_mod: tir.PrimFunc | tvm.IRModule,
-                 host_mod: tvm.IRModule | None = None,
-                 device_mod: tvm.IRModule | None = None,
-                 device_kernel_source: str | None = None,
-                 verbose: bool = False,
-                 pass_configs: dict[str, Any] | None = None,
-                 compile_flags: list[str] | None = None):
+    def __init__(
+        self,
+        params: list[KernelParam],
+        result_idx: list[int],
+        target: str | Target,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_mod: tvm.IRModule | None = None,
+        device_mod: tvm.IRModule | None = None,
+        device_kernel_source: str | None = None,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         check_nvrtc_available()
         self.params = params
@@ -92,17 +93,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
         self._post_init()
     @classmethod
-    def from_database(cls,
-                      params: list[KernelParam],
-                      result_idx: list[int],
-                      target: str,
-                      func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      host_kernel_source: str,
-                      device_kernel_source: str,
-                      kernel_lib_path: str,
-                      verbose: bool = False,
-                      pass_configs: dict[str, Any] | None = None,
-                      compile_flags: list[str] | None = None):
+    def from_database(
+        cls,
+        params: list[KernelParam],
+        result_idx: list[int],
+        target: str,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_kernel_source: str,
+        device_kernel_source: str,
+        kernel_lib_path: str,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
@@ -183,8 +186,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
         return self.host_func
     def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
-        """Low-level function to call the compiled CUDA kernel.
-        """
+        """Low-level function to call the compiled CUDA kernel."""
         return self.pymodule.call(self.kernels, *args, stream=stream)
     def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
......
@@ -13,6 +13,7 @@ Key responsibilities:
 - Load compiled cubin and extract kernel handles
 - Manage library lifecycle (load/unload)
 """
 from __future__ import annotations
 import importlib
 import logging
@@ -56,6 +57,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         culib: CUDA library handle (CUlibrary)
         pymodule: Imported Python module containing call() function
     """
     host_func: str | None = None
     culib: cuda.CUlibrary | None = None
     pymodule: ModuleType | None = None
@@ -131,10 +133,10 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         ctx = cuda.cuCtxGetCurrent()[1]
         if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
             import torch
             torch.cuda.synchronize()
-        result, self.culib = cuda.cuLibraryLoadFromFile(
-            bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
+        result, self.culib = cuda.cuLibraryLoadFromFile(bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
         if result != cuda.CUresult.CUDA_SUCCESS:
             raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}")
@@ -164,7 +166,8 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         target = self.target
         verbose = self.verbose
         if is_cuda_target(target):
-            from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
+            from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
             libpath = src.name.replace(".cu", ".cubin")
@@ -195,13 +198,9 @@ class NVRTCLibraryGenerator(LibraryGenerator):
                 f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}",
             ]
             if self.compile_flags:
-                options += [
-                    item for flag in self.compile_flags for item in flag.split()
-                    if item not in options
-                ]
+                options += [item for flag in self.compile_flags for item in flag.split() if item not in options]
-            cubin_bytes = compile_cuda(
-                self.lib_code, target_format="cubin", options=options, verbose=verbose)
+            cubin_bytes = compile_cuda(self.lib_code, target_format="cubin", options=options, verbose=verbose)
             with open(libpath, "wb") as f:
                 f.write(cubin_bytes)
@@ -212,8 +211,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
             self.libpath = libpath
             self.pypath = src.name.replace(".cu", ".py")
             if self.host_func is None:
-                raise RuntimeError(
-                    "Host function is not set, please call update_host_func() first.")
+                raise RuntimeError("Host function is not set, please call update_host_func() first.")
             with open(self.pypath, "w") as f:
                 f.write(self.host_func)
         else:
......
...@@ -12,6 +12,7 @@ Key design: ...@@ -12,6 +12,7 @@ Key design:
- Dict-based deduplication ensures TMA descriptors created only once - Dict-based deduplication ensures TMA descriptors created only once
- Generates pure Python using cuda.bindings.driver for zero C++ dependency - Generates pure Python using cuda.bindings.driver for zero C++ dependency
""" """
from __future__ import annotations from __future__ import annotations
from typing import Any, ClassVar from typing import Any, ClassVar
...@@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit ...@@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper
from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr, from tilelang.jit.adapter.utils import match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args
parse_function_call_args, parse_tma_descriptor_args)
PREDEF_HOST_FUNC_PY = """ PREDEF_HOST_FUNC_PY = """
from cuda.bindings.driver import ( from cuda.bindings.driver import (
...@@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
_generated_host_func: str | None = None _generated_host_func: str | None = None
def __init__(self, def __init__(
scheduled_ir_module: IRModule, self,
source: str, scheduled_ir_module: IRModule,
target: Target, source: str,
device_mod: IRModule | None = None, target: Target,
host_mod: IRModule | None = None, device_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None): host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None,
):
"""Initialize NVRTC wrapper with compiled IR modules. """Initialize NVRTC wrapper with compiled IR modules.
Args: Args:
...@@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
for param in self.prim_func.params: for param in self.prim_func.params:
if param in self.prim_func.buffer_map: if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param] buffer = self.prim_func.buffer_map[param]
function_args.append({ function_args.append(
"name": buffer.data.name, {
"type": "ctypes.c_void_p", "name": buffer.data.name,
}) "type": "ctypes.c_void_p",
}
)
elif isinstance(param, tvm.tir.Var): elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else: else:
raise ValueError( raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments # Add dynamic symbols as integer arguments
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
if dyn_sym not in [arg["name"] for arg in function_args]: if dyn_sym not in [arg["name"] for arg in function_args]:
...@@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
return (f"{name}.data_ptr()", arg_type) return (f"{name}.data_ptr()", arg_type)
return (name, arg_type) return (name, arg_type)
call_args = parse_function_call_args(declaration, function_args, function_params, call_args = parse_function_call_args(
desc_name_map, desc_name_var_map, declaration, function_args, function_params, desc_name_map, desc_name_var_map, transform_nvrtc_arg
transform_nvrtc_arg) )
for arg_name, arg_type in call_args: for arg_name, arg_type in call_args:
if arg_type == "ctypes.c_void_p": if arg_type == "ctypes.c_void_p":
...@@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
break break
# Store kernel info for second pass # Store kernel info for second pass
kernel_info_list.append({ kernel_info_list.append(
'function_name': function_name, {
'block_info': block_info, "function_name": function_name,
'grid_info': grid_info, "block_info": block_info,
'dynamic_smem_buf': dynamic_smem_buf, "grid_info": grid_info,
'call_args': call_args, "dynamic_smem_buf": dynamic_smem_buf,
'device_index': device_index, "call_args": call_args,
}) "device_index": device_index,
}
)
# Generate TMA descriptor initialization code once for all kernels # Generate TMA descriptor initialization code once for all kernels
kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map) kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
# Second pass: generate kernel launch code for each kernel # Second pass: generate kernel launch code for each kernel
for kernel_info in kernel_info_list: for kernel_info in kernel_info_list:
function_name = kernel_info['function_name'] function_name = kernel_info["function_name"]
block_info = kernel_info['block_info'] block_info = kernel_info["block_info"]
grid_info = kernel_info['grid_info'] grid_info = kernel_info["grid_info"]
dynamic_smem_buf = kernel_info['dynamic_smem_buf'] dynamic_smem_buf = kernel_info["dynamic_smem_buf"]
call_args = kernel_info['call_args'] call_args = kernel_info["call_args"]
device_index = kernel_info['device_index'] device_index = kernel_info["device_index"]
arg_names = ", ".join([arg[0] for arg in call_args]) arg_names = ", ".join([arg[0] for arg in call_args])
arg_types = ", ".join([arg[1] for arg in call_args]) arg_types = ", ".join([arg[1] for arg in call_args])
...@@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
kernel_launch_code += init_l2_persistent_map kernel_launch_code += init_l2_persistent_map
# Generate kernel launch code # Generate kernel launch code
kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(function_name, kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(
self._pythonic_expr(grid_info[0]), function_name,
self._pythonic_expr(grid_info[1]), self._pythonic_expr(grid_info[0]),
self._pythonic_expr(grid_info[2]), self._pythonic_expr(grid_info[1]),
self._pythonic_expr(block_info[0]), self._pythonic_expr(grid_info[2]),
self._pythonic_expr(block_info[1]), self._pythonic_expr(block_info[0]),
self._pythonic_expr(block_info[2]), self._pythonic_expr(block_info[1]),
smem_str, arg_names, arg_types, self._pythonic_expr(block_info[2]),
device_index) smem_str,
arg_names,
arg_types,
device_index,
)
# Reset L2 persistent map after all kernel execution # Reset L2 persistent map after all kernel execution
if has_l2_persistent_map: if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY
# Wrap the kernel dispatch logic in an external C function # Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC_PY.format( host_func = PREDEF_HOST_FUNC_PY.format(repr(list(function_informations.keys())), def_args, kernel_launch_code)
repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func return host_func
def generate_l2_persistent_map(self, function_name: str) -> str: def generate_l2_persistent_map(self, function_name: str) -> str:
...@@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
if function_name not in self.l2_persistent_map: if function_name not in self.l2_persistent_map:
return "" return ""
init_l2_persistent_map = "" init_l2_persistent_map = ""
for buffer_name, (hit_ratio, for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items():
size_in_bytes) in self.l2_persistent_map[function_name].items():
# Get persisting_l2_cache_max_size # Get persisting_l2_cache_max_size
from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
try: try:
num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
except TypeError: except TypeError:
# as size_in_bytes may be a symbolic expression # as size_in_bytes may be a symbolic expression
num_bytes = persisting_l2_cache_max_size num_bytes = persisting_l2_cache_max_size
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format( init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
return init_l2_persistent_map return init_l2_persistent_map
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
"""Generate Python code to initialize TMA descriptors. """Generate Python code to initialize TMA descriptors.
TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects
...@@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
return tma_descriptor_init return tma_descriptor_init
# Parse TMA descriptor arguments using the common utility # Parse TMA descriptor arguments using the common utility
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr)
desc_name_var_map, self._pythonic_expr)
# Generate Python code from parsed parameters # Generate Python code from parsed parameters
for params in parsed_params: for params in parsed_params:
if not params.is_img2col: if not params.is_img2col:
tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format( tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address, params.handle_name,
params.dtype,
params.tensor_rank,
params.global_address,
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)), ", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)),
", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)), ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)),
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
else: else:
tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format( tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address, params.handle_name,
params.dtype,
params.tensor_rank,
params.global_address,
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)), ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
", ".join(map(lambda x: f"cuuint32_t({x})", ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)),
params.element_strides)), ", ".join(params.lower_corner), ", ".join(params.lower_corner),
", ".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, ", ".join(params.upper_corner),
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) params.smem_box_channel,
params.smem_box_pixel,
params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
return tma_descriptor_init return tma_descriptor_init
...@@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params nonlocal function_params
if isinstance(node, tvm.tir.Call): if isinstance(node, tvm.tir.Call):
if not (hasattr(node, "op") and if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
return return
args = node.args args = node.args
if not args or args[0] != fn: if not args or args[0] != fn:
return return
if len(args) < 1 + param_cnt: if len(args) < 1 + param_cnt:
raise AssertionError( raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters")
"tvm_call_packed should have at least 1 argument and match device function parameters" function_params = args[1 : 1 + param_cnt]
)
function_params = args[1:1 + param_cnt]
post_order_visit(self.host_func.body, visitor) post_order_visit(self.host_func.body, visitor)
assert function_params is not None, "function_params should not be None" assert function_params is not None, "function_params should not be None"
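For readers unfamiliar with the visitor pattern above: post_order_visit walks every node of a TIR statement bottom-up and invokes the callback on each one, so locating the packed call reduces to filtering tir.Call nodes. A minimal sketch, assuming post_order_visit comes from tvm.tir.stmt_functor and host_func stands in for the lowered host-side PrimFunc used here:

from tilelang import tvm as tvm
from tvm.tir.stmt_functor import post_order_visit

calls = []

def collect_calls(node):
    # Gather every tir.Call seen during the bottom-up walk.
    if isinstance(node, tvm.tir.Call):
        calls.append(node)

post_order_visit(host_func.body, collect_calls)  # host_func: assumed lowered host PrimFunc
packed = [c for c in calls if c.op == tvm.ir.Op.get("tir.tvm_call_packed")]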
......
from .metal import MetalKernelAdapter from .metal import MetalKernelAdapter
__all__ = ['MetalKernelAdapter'] __all__ = ["MetalKernelAdapter"]
...@@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam ...@@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam
class MetalKernelAdapter(BaseKernelAdapter): class MetalKernelAdapter(BaseKernelAdapter):
def __init__( def __init__(
self, self,
params: list[KernelParam], params: list[KernelParam],
...@@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter):
): ):
self.kernel_global_source = kernel_global_source self.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc): if isinstance(func_or_mod, tir.PrimFunc):
func_name = func_or_mod.attrs['global_symbol'] func_name = func_or_mod.attrs["global_symbol"]
else: else:
func_name = func_or_mod.__name__ func_name = func_or_mod.__name__
self.kernel_name = func_name + '_kernel' self.kernel_name = func_name + "_kernel"
self.verbose = verbose self.verbose = verbose
self.block_info = [1, 1, 1] self.block_info = [1, 1, 1]
...@@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
for var, func in device_mod.functions.items(): for var, func in device_mod.functions.items():
assert var.name_hint == self.kernel_name assert var.name_hint == self.kernel_name
thread_extent = func.attrs['thread_extent'] thread_extent = func.attrs["thread_extent"]
for tag, extent in thread_extent.items(): for tag, extent in thread_extent.items():
if "threadIdx" in tag: if "threadIdx" in tag:
self.block_info["xyz".index(tag[-1])] = extent self.block_info["xyz".index(tag[-1])] = extent
...@@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
self.grid_info["xyz".index(tag[-1])] = extent self.grid_info["xyz".index(tag[-1])] = extent
break break
else: else:
raise AssertionError(f'no kernel with name {func_name}') raise AssertionError(f"no kernel with name {func_name}")
# print(self.block_info, self.grid_info) # print(self.block_info, self.grid_info)
super().__init__(func_or_mod, result_idx=result_idx, params=params) super().__init__(func_or_mod, result_idx=result_idx, params=params)
...@@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter):
_kernel = None _kernel = None
def _convert_torch_func(self) -> Callable: def _convert_torch_func(self) -> Callable:
if self._kernel is None: if self._kernel is None:
_kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name) _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name)
_threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)] _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)]
@wraps(_kernel) @wraps(_kernel)
def launcher(*args: torch.Tensor): def launcher(*args: torch.Tensor):
return _kernel( return _kernel(
*args, *args,
threads=_threads, threads=_threads,
......
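As a hedged illustration of the Metal launch path above, the sketch below compiles a trivial kernel and dispatches it the same way the adapter's launcher does; it assumes a PyTorch build that ships torch.mps.compile_shader and an available MPS device, and the kernel source and names are invented for the example:

import torch

metal_src = """
#include <metal_stdlib>
using namespace metal;
kernel void fill_one(device float *out [[buffer(0)]],
                     uint idx [[thread_position_in_grid]]) {
    out[idx] = 1.0f;
}
"""

lib = torch.mps.compile_shader(metal_src)        # runtime compilation of the Metal source
out = torch.zeros(64, device="mps")
lib.fill_one(out, threads=out.numel())           # flat dispatch, mirroring the threads= launcher above
assert bool(torch.all(out == 1))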
...@@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked, ...@@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked,
the execution observes the same stream context as the active Torch code. the execution observes the same stream context as the active Torch code.
On non-CUDA builds, the stream/device fall back to 0/CPU semantics. On non-CUDA builds, the stream/device fall back to 0/CPU semantics.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Callable, Any from typing import Callable, Any
...@@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
- The stream pointer returned is a raw CUDA stream handle compatible with - The stream pointer returned is a raw CUDA stream handle compatible with
TVM's device API; on CPU or when CUDA is unavailable, we return 0. TVM's device API; on CPU or when CUDA is unavailable, we return 0.
""" """
# Class attributes to store compiled kernel information # Class attributes to store compiled kernel information
target: str | Target = "cuda" target: str | Target = "cuda"
ir_module: tvm.IRModule | None = None ir_module: tvm.IRModule | None = None
...@@ -51,19 +53,21 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -51,19 +53,21 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None
# Stream/device functors are inherited from BaseKernelAdapter # Stream/device functors are inherited from BaseKernelAdapter
def __init__(self, def __init__(
params: list[KernelParam], self,
result_idx: list[int], params: list[KernelParam],
target: str | Target, result_idx: list[int],
func_or_mod: tir.PrimFunc | tvm.IRModule, target: str | Target,
host_mod: tvm.IRModule | None = None, func_or_mod: tir.PrimFunc | tvm.IRModule,
device_mod: tvm.IRModule | None = None, host_mod: tvm.IRModule | None = None,
rt_mod: tvm.runtime.Module | None = None, device_mod: tvm.IRModule | None = None,
host_kernel_source: str | None = None, rt_mod: tvm.runtime.Module | None = None,
device_kernel_source: str | None = None, host_kernel_source: str | None = None,
verbose: bool = False, device_kernel_source: str | None = None,
pass_configs: dict[str, Any] | None = None, verbose: bool = False,
compile_flags: list[str] | None = None): pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
"""Initialize the adapter with the given TIR function or module. """Initialize the adapter with the given TIR function or module.
Args: Args:
...@@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
if param in buffer_map: if param in buffer_map:
buffer = buffer_map[param] buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape): for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
(shape not in params)):
dynamic_symbolic_map[shape] = (0, i, j) dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params): for i, param in enumerate(params):
if param in buffer_map: if param in buffer_map:
buffer = buffer_map[param] buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides): for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j) dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
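To make the encoding above concrete, here is an illustrative, hand-written picture of the map for a kernel taking A of shape (m, k) and B of shape (k, n) with all three dims symbolic; the real keys are tir.Var objects rather than strings:

dynamic_symbolic_map = {
    "m": (0, 0, 0),  # kind 0 = shape: resolve m from tensor 0, dim 0 at call time
    "k": (0, 0, 1),  # resolve k from tensor 0, dim 1
    "n": (0, 1, 1),  # resolve n from tensor 1, dim 1
}
# kind 1 would mean "read tensor.stride()[j]"; kind 2 marks a symbol that is itself a scalar
# argument, matching the ref_id branches used when output shapes are materialized below.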
...@@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
# Validate input count strictly # Validate input count strictly
expected_inputs = len(self.params) - len(self.result_idx) expected_inputs = len(self.params) - len(self.result_idx)
if len(inputs) != expected_inputs: if len(inputs) != expected_inputs:
raise ValueError( raise ValueError(f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")
f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")
# Resolve the device used for outputs. Prefer the first tensor input's device # Resolve the device used for outputs. Prefer the first tensor input's device
# if available, otherwise use PyTorch's current device. # if available, otherwise use PyTorch's current device.
...@@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
for s in param_shapes[i]: for s in param_shapes[i]:
if isinstance(s, tir.Var): if isinstance(s, tir.Var):
for key in dynamic_symbolic_map: for key in dynamic_symbolic_map:
if (str(s) == str(key)): if str(s) == str(key):
ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[ ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key]
key]
if ref_id == 2: if ref_id == 2:
shape.append(inputs[ref_tensor_idx]) shape.append(inputs[ref_tensor_idx])
elif ref_id == 0: elif ref_id == 0:
shape.append( shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
tensor_list[ref_tensor_idx].shape[ref_shape_idx])
elif ref_id == 1: elif ref_id == 1:
shape.append( shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
else: # Already converted to Python int during initialization else: # Already converted to Python int during initialization
shape.append(s) shape.append(s)
...@@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
out_device = current_device_functor() out_device = current_device_functor()
if len(shape) == 0: if len(shape) == 0:
param_name = self.params[i].name if hasattr(self.params[i], param_name = self.params[i].name if hasattr(self.params[i], "name") else f"parameter_{i}"
'name') else f'parameter_{i}'
raise ValueError( raise ValueError(
f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. " f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. "
f"Expected shape: {shape}") f"Expected shape: {shape}"
)
tensor = torch.empty(*shape, dtype=dtype, device=out_device) tensor = torch.empty(*shape, dtype=dtype, device=out_device)
else: else:
tensor = inputs[ins_idx] tensor = inputs[ins_idx]
...@@ -256,17 +254,19 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -256,17 +254,19 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
return func return func
@classmethod @classmethod
def from_database(cls, def from_database(
params: list[TensorType], cls,
result_idx: list[int], params: list[TensorType],
target: str, result_idx: list[int],
func_or_mod: tir.PrimFunc | tvm.IRModule, target: str,
host_kernel_source: str, func_or_mod: tir.PrimFunc | tvm.IRModule,
device_kernel_source: str, host_kernel_source: str,
kernel_lib_path: str, device_kernel_source: str,
verbose: bool = False, kernel_lib_path: str,
pass_configs: dict[str, Any] | None = None, verbose: bool = False,
compile_flags: list[str] | None = None): pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
......
...@@ -70,7 +70,6 @@ def get_annotated_mod( ...@@ -70,7 +70,6 @@ def get_annotated_mod(
target_host: str | Target | None = None, target_host: str | Target | None = None,
model_type: Literal["device", "host", "all"] = "all", model_type: Literal["device", "host", "all"] = "all",
) -> IRModule | tuple[IRModule, IRModule]: ) -> IRModule | tuple[IRModule, IRModule]:
# Validate model_type early # Validate model_type early
if model_type not in {"device", "host", "all"}: if model_type not in {"device", "host", "all"}:
raise ValueError(f"Invalid model type: {model_type}") raise ValueError(f"Invalid model type: {model_type}")
...@@ -95,21 +94,15 @@ def get_annotated_mod( ...@@ -95,21 +94,15 @@ def get_annotated_mod(
# Define dispatch dictionary for different model types # Define dispatch dictionary for different model types
dispatch = { dispatch = {
"device": "device": lambda m: tir.transform.Filter(_is_device_call)(m),
lambda m: tir.transform.Filter(_is_device_call)(m), "host": lambda m: tir.transform.Filter(_is_host_call)(m),
"host": "all": lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)(m)),
lambda m: tir.transform.Filter(_is_host_call)(m),
"all":
lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)
(m)),
} }
return dispatch[model_type](mod) return dispatch[model_type](mod)
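A hedged usage sketch of the dispatch above (construction of mod and target is elided, and the import path for get_annotated_mod is assumed to be the same utils module the wrappers import from):

device_mod, host_mod = get_annotated_mod(mod, target)           # model_type="all" (default)
host_only = get_annotated_mod(mod, target, model_type="host")   # a single filtered IRModule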
def pythonic_expr(expr: tvm.tir.PrimExpr, def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False) -> str:
dtype_map: dict[str, str] | None = None,
ignore_cast: bool = False) -> str:
""" """
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
...@@ -168,9 +161,23 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, ...@@ -168,9 +161,23 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
s = f"({type_str}){value_str}" s = f"({type_str}){value_str}"
p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE) p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
elif isinstance( elif isinstance(
node, node,
(tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT, (
tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)): tvm.tir.Mul,
tvm.tir.FloorDiv,
tvm.tir.Add,
tvm.tir.Sub,
tvm.tir.FloorMod,
tvm.tir.LT,
tvm.tir.LE,
tvm.tir.GT,
tvm.tir.GE,
tvm.tir.EQ,
tvm.tir.NE,
tvm.tir.And,
tvm.tir.Or,
),
):
op_map = { op_map = {
tvm.tir.Mul: "*", tvm.tir.Mul: "*",
tvm.tir.FloorDiv: "/", tvm.tir.FloorDiv: "/",
...@@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, ...@@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
return next(iter(node_to_result_map[expr]), "") return next(iter(node_to_result_map[expr]), "")
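A hedged usage sketch of pythonic_expr; the exact rendered string depends on the op_map and dtype_map above, so the comments are only indicative, and the function is assumed to be imported from this utils module:

from tilelang import tvm as tvm
from tvm import tir

m = tir.Var("m", "int32")
expr = (m + 1) * 2
s = pythonic_expr(expr)                      # precedence-aware rendering, e.g. "(m + 1) * 2"
t = pythonic_expr(expr, ignore_cast=True)    # casts, if any, are dropped from the output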
def maybe_desc_name(name: str, def maybe_desc_name(name: str, matches: list[str], i: int, desc_name_map: dict[str, str] | None = None) -> bool:
matches: list[str],
i: int,
desc_name_map: dict[str, str] | None = None) -> bool:
""" """
Check if a parameter name corresponds to a TMA descriptor. Check if a parameter name corresponds to a TMA descriptor.
...@@ -290,8 +294,7 @@ def parse_function_call_args( ...@@ -290,8 +294,7 @@ def parse_function_call_args(
else: else:
call_args.append(match) call_args.append(match)
if desc_name_var_map is not None and function_params is not None: if desc_name_var_map is not None and function_params is not None:
assert len(call_args) <= len(function_params), \ assert len(call_args) <= len(function_params), f"Too many arguments: {len(call_args)} > {len(function_params)}"
f"Too many arguments: {len(call_args)} > {len(function_params)}"
desc_name_var_map[match] = function_params[len(call_args) - 1] desc_name_var_map[match] = function_params[len(call_args) - 1]
return call_args return call_args
...@@ -300,12 +303,7 @@ def parse_function_call_args( ...@@ -300,12 +303,7 @@ def parse_function_call_args(
class TMADescriptorParams: class TMADescriptorParams:
"""Parsed TMA descriptor parameters.""" """Parsed TMA descriptor parameters."""
def __init__(self, def __init__(self, handle_name: str, dtype: str, tensor_rank: int, global_address: Any, is_img2col: bool = False):
handle_name: str,
dtype: str,
tensor_rank: int,
global_address: Any,
is_img2col: bool = False):
self.handle_name = handle_name self.handle_name = handle_name
self.dtype = dtype self.dtype = dtype
self.tensor_rank = tensor_rank self.tensor_rank = tensor_rank
...@@ -355,22 +353,19 @@ def parse_tma_descriptor_args( ...@@ -355,22 +353,19 @@ def parse_tma_descriptor_args(
results = [] results = []
for handle_name, _ in desc_name_map.items(): for handle_name, _ in desc_name_map.items():
assert handle_name in desc_name_var_map, \ assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
f"Handle name {handle_name} not found in desc_name_var_map"
desc_var = desc_name_var_map[handle_name] desc_var = desc_name_var_map[handle_name]
assert desc_var in tma_descriptor_args, \ assert desc_var in tma_descriptor_args, f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
args = tma_descriptor_args[desc_var] args = tma_descriptor_args[desc_var]
# Skip __tvm_tensormap_create_tiled and second element (like CUDA version) # Skip __tvm_tensormap_create_tiled and second element (like CUDA version)
if len(args) < 3: if len(args) < 3:
raise ValueError( raise ValueError(f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args
is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col") is_img2col = tma_create_str.value == "__tvm_tensormap_create_im2col"
# Convert basic fields # Convert basic fields
dtype = pythonic_expr_func(dtype) dtype = pythonic_expr_func(dtype)
...@@ -386,60 +381,45 @@ def parse_tma_descriptor_args( ...@@ -386,60 +381,45 @@ def parse_tma_descriptor_args(
# Tiled mode # Tiled mode
expected_args_len = 4 * tensor_rank + 4 expected_args_len = 4 * tensor_rank + 4
if len(remaining_args) < expected_args_len: if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " raise ValueError(
f"expected {expected_args_len} for tensor_rank {tensor_rank}") f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
)
# Extract dimensions and strides # Extract dimensions and strides
params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
params.global_stride = [ params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]]
pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] params.box_dim = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]]
] params.element_strides = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank]]
params.box_dim = [
pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
]
params.element_strides = [
pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank]
]
# Extract remaining parameters # Extract remaining parameters
try: try:
interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 * interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank : 4 * tensor_rank + 4]
tensor_rank + 4]
params.interleave = pythonic_expr_func(interleave) params.interleave = pythonic_expr_func(interleave)
params.swizzle = pythonic_expr_func(swizzle) params.swizzle = pythonic_expr_func(swizzle)
params.l2_promotion = pythonic_expr_func(l2_promotion) params.l2_promotion = pythonic_expr_func(l2_promotion)
params.oob_fill = pythonic_expr_func(oob_fill) params.oob_fill = pythonic_expr_func(oob_fill)
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError("Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)") from e
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
else: else:
# Im2col mode # Im2col mode
expected_args_len = 5 * tensor_rank + 2 expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len: if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, " raise ValueError(
f"expected {expected_args_len} for tensor_rank {tensor_rank}") f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
)
# Extract dimensions and strides # Extract dimensions and strides
params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]] params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
params.global_stride = [ params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]]
pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank] params.element_strides = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]]
] params.lower_corner = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank - 2]]
params.element_strides = [ params.upper_corner = [pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2 : 5 * tensor_rank - 4]]
pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
]
params.lower_corner = [
pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
]
params.upper_corner = [
pythonic_expr_func(i)
for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
]
# Extract remaining parameters # Extract remaining parameters
try: try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \ smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = remaining_args[
remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2] 5 * tensor_rank - 4 : 5 * tensor_rank + 2
]
params.smem_box_pixel = pythonic_expr_func(smem_box_pixel) params.smem_box_pixel = pythonic_expr_func(smem_box_pixel)
params.smem_box_channel = pythonic_expr_func(smem_box_channel) params.smem_box_channel = pythonic_expr_func(smem_box_channel)
params.interleave = pythonic_expr_func(interleave) params.interleave = pythonic_expr_func(interleave)
......
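A worked example of the tiled-mode slicing above for tensor_rank = 2, so 4 * rank + 4 = 12 trailing args; plain integers stand in for the PrimExpr entries:

rank = 2
remaining_args = list(range(12))                           # stand-in values
global_dim      = remaining_args[:rank]                    # [0, 1]
global_stride   = remaining_args[rank:2 * rank]            # [2, 3]
box_dim         = remaining_args[2 * rank:3 * rank]        # [4, 5]
element_strides = remaining_args[3 * rank:4 * rank]        # [6, 7]
interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * rank:4 * rank + 4]  # 8..11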
...@@ -4,9 +4,18 @@ from tilelang import tvm as tvm ...@@ -4,9 +4,18 @@ from tilelang import tvm as tvm
from typing import Any from typing import Any
from tvm import IRModule from tvm import IRModule
from tvm.target import Target from tvm.target import Target
from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, from .utils import (
is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr, is_metal_target,
parse_function_call_args, parse_tma_descriptor_args) match_declare_kernel,
match_declare_kernel_cpu,
is_cuda_target,
is_hip_target,
is_cpu_target,
get_annotated_mod,
pythonic_expr,
parse_function_call_args,
parse_tma_descriptor_args,
)
import re import re
import logging import logging
import textwrap import textwrap
...@@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """ ...@@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """
class BaseWrapper(ABC): class BaseWrapper(ABC):
@abstractmethod @abstractmethod
def wrap(self, *args, **kwargs): def wrap(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
...@@ -163,13 +171,15 @@ class TLCUDASourceWrapper: ...@@ -163,13 +171,15 @@ class TLCUDASourceWrapper:
host_mod: IRModule | None = None host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None pass_configs: dict[str, Any] | None = None
def __init__(self, def __init__(
scheduled_ir_module: IRModule, self,
source: str, scheduled_ir_module: IRModule,
target: Target, source: str,
device_mod: IRModule | None = None, target: Target,
host_mod: IRModule | None = None, device_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None): host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None,
):
self.mod = scheduled_ir_module self.mod = scheduled_ir_module
self.target = target self.target = target
self.source = source self.source = source
...@@ -211,15 +221,16 @@ class TLCUDASourceWrapper: ...@@ -211,15 +221,16 @@ class TLCUDASourceWrapper:
for param in self.prim_func.params: for param in self.prim_func.params:
if param in self.prim_func.buffer_map: if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param] buffer = self.prim_func.buffer_map[param]
function_args.append({ function_args.append(
"name": buffer.data.name, {
"type": self._lookup_type(buffer.dtype) + "* __restrict__", "name": buffer.data.name,
}) "type": self._lookup_type(buffer.dtype) + "* __restrict__",
}
)
elif isinstance(param, tvm.tir.Var): elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else: else:
raise ValueError( raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments # Add dynamic symbols as integer arguments
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
if dyn_sym not in [arg["name"] for arg in function_args]: if dyn_sym not in [arg["name"] for arg in function_args]:
...@@ -256,38 +267,40 @@ class TLCUDASourceWrapper: ...@@ -256,38 +267,40 @@ class TLCUDASourceWrapper:
# Identify the start of the function body to insert arguments # Identify the start of the function body to insert arguments
index = code.index("{", index) index = code.index("{", index)
block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})" block_str = (
grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})" f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
)
grid_str = (
f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
)
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name) init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map kernel_launch_code += init_l2_persistent_map
if self.use_cooperative_groups[function_name]: if self.use_cooperative_groups[function_name]:
args_list = parse_function_call_args(declaration, function_args, function_params, args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map)
desc_name_map, desc_name_var_map) assert len(function_params) == len(args_list), (
assert len(function_params) == len( f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
args_list )
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
args_array = [f"(void*)&{arg}" for arg in args_list] args_array = [f"(void*)&{arg}" for arg in args_list]
call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n" call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
kernel_launch_code += call_args kernel_launch_code += call_args
# Using cudaLaunchCooperativeKernel to launch the kernel # Using cudaLaunchCooperativeKernel to launch the kernel
kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format( kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
function_name, grid_str, block_str, function_name + "_args", smem_str) function_name, grid_str, block_str, function_name + "_args", smem_str
)
else: else:
args_list = parse_function_call_args(declaration, function_args, function_params, args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map)
desc_name_map, desc_name_var_map) assert len(function_params) == len(args_list), (
assert len(function_params) == len( f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
args_list )
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
call_args = ", ".join(args_list) call_args = ", ".join(args_list)
kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n" kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n" kernel_launch_code += f'\tTILELANG_CHECK_LAST_ERROR("{function_name}");\n'
if has_l2_persistent_map: if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
desc_name_var_map)
kernel_launch_code = init_tma_descriptor_args + kernel_launch_code kernel_launch_code = init_tma_descriptor_args + kernel_launch_code
# Wrap the kernel dispatch logic in an external C function # Wrap the kernel dispatch logic in an external C function
...@@ -298,46 +311,63 @@ class TLCUDASourceWrapper: ...@@ -298,46 +311,63 @@ class TLCUDASourceWrapper:
if function_name not in self.l2_persistent_map: if function_name not in self.l2_persistent_map:
return "" return ""
init_l2_persistent_map = "" init_l2_persistent_map = ""
for buffer_name, (hit_ratio, for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items():
size_in_bytes) in self.l2_persistent_map[function_name].items():
# get persisting_l2_cache_max_size # get persisting_l2_cache_max_size
from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
persisting_l2_cache_max_size = get_persisting_l2_cache_max_size() persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
try: try:
num_bytes = min(size_in_bytes, persisting_l2_cache_max_size) num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
except Exception: except Exception:
# as size_in_bytes may be a symbolic expression # as size_in_bytes may be a symbolic expression
num_bytes = persisting_l2_cache_max_size num_bytes = persisting_l2_cache_max_size
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format( init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
return init_l2_persistent_map return init_l2_persistent_map
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
tma_descripter_init = "" tma_descripter_init = ""
if self.tma_descriptor_args is None: if self.tma_descriptor_args is None:
return tma_descripter_init return tma_descripter_init
# Parse TMA descriptor arguments using the common utility # Parse TMA descriptor arguments using the common utility
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr)
desc_name_var_map, self._pythonic_expr)
# Generate C++ code from parsed parameters # Generate C++ code from parsed parameters
for params in parsed_params: for params in parsed_params:
if not params.is_img2col: if not params.is_img2col:
tma_descripter_init += TMA_DESC_INIT_FUNC.format( tma_descripter_init += TMA_DESC_INIT_FUNC.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address, params.handle_name,
",".join(params.global_dim), ",".join(params.global_stride), params.dtype,
",".join(params.box_dim), ",".join(params.element_strides), params.interleave, params.tensor_rank,
params.swizzle, params.l2_promotion, params.oob_fill) params.global_address,
",".join(params.global_dim),
",".join(params.global_stride),
",".join(params.box_dim),
",".join(params.element_strides),
params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
else: else:
tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format( tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address, params.handle_name,
",".join(params.global_dim), ",".join(params.global_stride), params.dtype,
",".join(params.element_strides), ",".join(params.lower_corner), params.tensor_rank,
",".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel, params.global_address,
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill) ",".join(params.global_dim),
",".join(params.global_stride),
",".join(params.element_strides),
",".join(params.lower_corner),
",".join(params.upper_corner),
params.smem_box_channel,
params.smem_box_pixel,
params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
return tma_descripter_init return tma_descripter_init
...@@ -347,9 +377,8 @@ class TLCUDASourceWrapper: ...@@ -347,9 +377,8 @@ class TLCUDASourceWrapper:
device_mod, host_mod = get_annotated_mod(self.mod, self.target) device_mod, host_mod = get_annotated_mod(self.mod, self.target)
self.device_mod = device_mod self.device_mod = device_mod
self.host_mod = host_mod self.host_mod = host_mod
assert (len(self.device_mod.functions) assert len(self.device_mod.functions) >= 1, "Device module should have at least one function."
>= 1), "Device module should have at least one function." assert len(self.host_mod.functions) == 1, "Only support one function in host module."
assert (len(self.host_mod.functions) == 1), "Only support one function in host module."
block_info_map = {} block_info_map = {}
grid_info_map = {} grid_info_map = {}
...@@ -438,8 +467,7 @@ class TLCUDASourceWrapper: ...@@ -438,8 +467,7 @@ class TLCUDASourceWrapper:
for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
if dynamic_smem_buf is not None: if dynamic_smem_buf is not None:
# Format the cudaFuncSetAttribute call for dynamic shared memory # Format the cudaFuncSetAttribute call for dynamic shared memory
call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format( call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf)
function_name, dynamic_smem_buf)
# Format the initialization function using the call_str # Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str) init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs return init_funcs
...@@ -466,17 +494,14 @@ class TLCUDASourceWrapper: ...@@ -466,17 +494,14 @@ class TLCUDASourceWrapper:
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params nonlocal function_params
if isinstance(node, tvm.tir.Call): if isinstance(node, tvm.tir.Call):
if not (hasattr(node, "op") and if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
return return
args = node.args args = node.args
if not args or args[0] != fn: if not args or args[0] != fn:
return return
if len(args) < 1 + param_cnt: if len(args) < 1 + param_cnt:
raise AssertionError( raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters")
"tvm_call_packed should have at least 1 argument and match device function parameters" function_params = args[1 : 1 + param_cnt]
)
function_params = args[1:1 + param_cnt]
post_order_visit(self.host_func.body, visitor) post_order_visit(self.host_func.body, visitor)
assert function_params is not None, "function_params should not be None" assert function_params is not None, "function_params should not be None"
...@@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): ...@@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"uchar": "uint8_t", "uchar": "uint8_t",
} }
def __init__(self, def __init__(
scheduled_ir_module: IRModule, self,
source: str, scheduled_ir_module: IRModule,
target: Target, source: str,
device_mod: IRModule | None = None, target: Target,
host_mod: IRModule | None = None, device_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None): host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None,
):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def get_init_func(self): def get_init_func(self):
...@@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): ...@@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
if dynamic_smem_buf is not None: if dynamic_smem_buf is not None:
# Format the cudaFuncSetAttribute call for dynamic shared memory # Format the cudaFuncSetAttribute call for dynamic shared memory
call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format( call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(function_name, dynamic_smem_buf)
function_name, dynamic_smem_buf)
# Format the initialization function using the call_str # Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str) init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs return init_funcs
...@@ -623,13 +649,15 @@ class TLCPUSourceWrapper: ...@@ -623,13 +649,15 @@ class TLCPUSourceWrapper:
host_mod: IRModule | None = None host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None pass_configs: dict[str, Any] | None = None
def __init__(self, def __init__(
scheduled_ir_module: IRModule, self,
source: str, scheduled_ir_module: IRModule,
target: Target, source: str,
device_mod: IRModule | None = None, target: Target,
host_mod: IRModule | None = None, device_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None): host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None,
):
self.mod = scheduled_ir_module self.mod = scheduled_ir_module
self.target = target self.target = target
self.source = source self.source = source
...@@ -658,15 +686,16 @@ class TLCPUSourceWrapper: ...@@ -658,15 +686,16 @@ class TLCPUSourceWrapper:
for param in self.prim_func.params: for param in self.prim_func.params:
if param in self.prim_func.buffer_map: if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param] buffer = self.prim_func.buffer_map[param]
function_args.append({ function_args.append(
"name": buffer.name, {
"type": self._lookup_type(buffer.dtype) + "*", "name": buffer.name,
}) "type": self._lookup_type(buffer.dtype) + "*",
}
)
elif isinstance(param, tvm.tir.Var): elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else: else:
raise ValueError( raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments # Add dynamic symbols as integer arguments
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set: for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)}) function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})
...@@ -686,7 +715,6 @@ class TLCPUSourceWrapper: ...@@ -686,7 +715,6 @@ class TLCPUSourceWrapper:
_call_str = """""" _call_str = """"""
for function_name, _ in function_informations.items(): for function_name, _ in function_informations.items():
# Find the location of the global kernel function in the code # Find the location of the global kernel function in the code
index = match_declare_kernel_cpu(code, function_name + "(") index = match_declare_kernel_cpu(code, function_name + "(")
...@@ -706,8 +734,8 @@ class TLCPUSourceWrapper: ...@@ -706,8 +734,8 @@ class TLCPUSourceWrapper:
def parse_source_information(self): def parse_source_information(self):
with tvm.transform.PassContext(opt_level=3, config=self.pass_configs): with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
device_mod, host_mod = get_annotated_mod(self.mod, self.target) device_mod, host_mod = get_annotated_mod(self.mod, self.target)
assert (len(device_mod.functions) >= 1), "Device module should have at least one function." assert len(device_mod.functions) >= 1, "Device module should have at least one function."
assert (len(host_mod.functions) == 1), "Only support one function in host module." assert len(host_mod.functions) == 1, "Only support one function in host module."
function_names = [] function_names = []
for g_var, _ in device_mod.functions.items(): for g_var, _ in device_mod.functions.items():
...@@ -767,14 +795,15 @@ class TLCPUSourceWrapper: ...@@ -767,14 +795,15 @@ class TLCPUSourceWrapper:
class TLMetalSourceWrapper: class TLMetalSourceWrapper:
def __init__(
def __init__(self, self,
scheduled_ir_module: IRModule, scheduled_ir_module: IRModule,
source: str, source: str,
target: Target, target: Target,
device_mod: IRModule | None = None, device_mod: IRModule | None = None,
host_mod: IRModule | None = None, host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None): pass_configs: dict[str, Any] | None = None,
):
self.mod = scheduled_ir_module self.mod = scheduled_ir_module
self.target = target self.target = target
self.source = source self.source = source
...@@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper): ...@@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper):
""" """
A wrapper class for the TileLang backend. A wrapper class for the TileLang backend.
""" """
device_mod: IRModule | None = None device_mod: IRModule | None = None
host_mod: IRModule | None = None host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None pass_configs: dict[str, Any] | None = None
...@@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper): ...@@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper):
target=self.target, target=self.target,
device_mod=self.device_mod, device_mod=self.device_mod,
host_mod=self.host_mod, host_mod=self.host_mod,
pass_configs=self.pass_configs) pass_configs=self.pass_configs,
)
return wrapper.lib_code return wrapper.lib_code
class TLPyWrapper(TLWrapper): class TLPyWrapper(TLWrapper):
def __init__(self, target: Target): def __init__(self, target: Target):
super().__init__(target) super().__init__(target)
...@@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper): ...@@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper):
# assert self.scheduled_ir_module is not None, "Please assign optimized module first." # assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if is_cuda_target(self.target): if is_cuda_target(self.target):
from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper
wrapper_class = TLNVRTCSourceWrapper wrapper_class = TLNVRTCSourceWrapper
else: else:
raise ValueError(f"Unsupported target for NVRTC backend: {self.target}") raise ValueError(f"Unsupported target for NVRTC backend: {self.target}")
...@@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper): ...@@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper):
target=self.target, target=self.target,
device_mod=self.device_mod, device_mod=self.device_mod,
host_mod=self.host_mod, host_mod=self.host_mod,
pass_configs=self.pass_configs) pass_configs=self.pass_configs,
)
return wrapper.host_func, wrapper.function_names return wrapper.host_func, wrapper.function_names
...@@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T ...@@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
# Drop NVRTC if not importable # Drop NVRTC if not importable
try: try:
from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy
if not is_nvrtc_available and "nvrtc" in allowed: if not is_nvrtc_available and "nvrtc" in allowed:
allowed = [b for b in allowed if b != "nvrtc"] allowed = [b for b in allowed if b != "nvrtc"]
except Exception: except Exception:
...@@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str: ...@@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
if req not in allowed_all: if req not in allowed_all:
raise ValueError( raise ValueError(
f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. " f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. "
f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.") f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'."
)
# Promote to availability-aware set for nicer errors (e.g., nvrtc not installed) # Promote to availability-aware set for nicer errors (e.g., nvrtc not installed)
if req not in allowed_avail: if req not in allowed_avail:
raise ValueError( raise ValueError(
f"Execution backend '{requested}' requires extra dependencies and is not available now. " f"Execution backend '{requested}' requires extra dependencies and is not available now. "
f"Try one of: {_format_options(allowed_avail)}.") f"Try one of: {_format_options(allowed_avail)}."
)
return req return req
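A hedged usage sketch of the resolution logic above (the module path is an assumption, and Target construction follows standard TVM usage):

from tvm.target import Target
from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target

target = Target("cuda")
print(allowed_backends_for_target(target))            # backends considered valid for this target
print(resolve_execution_backend("auto", target))      # resolves to a concrete, available backend
resolve_execution_backend("no-such-backend", target)  # raises ValueError listing the allowed options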
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Generic, Literal, TypeVar from typing import Any, Callable, Generic, Literal, TypeVar
# Python 3.9 compatibility for ParamSpec # Python 3.9 compatibility for ParamSpec
try: try:
from typing import ParamSpec from typing import ParamSpec
...@@ -14,8 +15,7 @@ import tilelang ...@@ -14,8 +15,7 @@ import tilelang
from tilelang import tvm from tilelang import tvm
from tilelang import env from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter
TVMFFIKernelAdapter, MetalKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc from tilelang.contrib import nvcc as tl_nvcc
...@@ -24,8 +24,8 @@ import os ...@@ -24,8 +24,8 @@ import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_P = ParamSpec('_P') _P = ParamSpec("_P")
_T = TypeVar('_T') _T = TypeVar("_T")
class JITKernel(Generic[_P, _T]): class JITKernel(Generic[_P, _T]):
...@@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]):
torch_function : Callable torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function. The compiled function that can be invoked as a PyTorch-compatible function.
""" """
prim_func: PrimFunc = None prim_func: PrimFunc = None
artifact: CompiledArtifact = None artifact: CompiledArtifact = None
adapter: BaseKernelAdapter = None adapter: BaseKernelAdapter = None
...@@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]):
if execution_backend == "cython": if execution_backend == "cython":
from tilelang.contrib.cc import get_cplus_compiler from tilelang.contrib.cc import get_cplus_compiler
assert ( assert get_cplus_compiler() is not None, "Cython backend requires a C++ compiler; please install one or use another backend."
get_cplus_compiler() is not None
), "Cython backend requires a C++ compiler; please install one or use another backend."
if from_database: if from_database:
return return
...@@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]):
""" """
return self.torch_function(*args, **kwds) return self.torch_function(*args, **kwds)
def _compile_and_create_adapter(self, tilelang_func: PrimFunc, def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int]) -> BaseKernelAdapter:
out_idx: list[int]) -> BaseKernelAdapter:
""" """
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter. Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
...@@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]): ...@@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]):
target=target, target=target,
target_host=target_host, target_host=target_host,
enable_host_codegen=enable_host_codegen, enable_host_codegen=enable_host_codegen,
enable_device_compile=enable_device_compile) enable_device_compile=enable_device_compile,
)
self.artifact = artifact self.artifact = artifact
...@@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]):
if execution_backend == "tvm_ffi": if execution_backend == "tvm_ffi":
# Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack. # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
# But we need to ensure that the runtime is enabled and the runtime module is not None. # But we need to ensure that the runtime is enabled and the runtime module is not None.
assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module." assert artifact.rt_mod is not None, "tvm_ffi backend requires a runtime module."
adapter = TVMFFIKernelAdapter( adapter = TVMFFIKernelAdapter(
params=artifact.params, params=artifact.params,
result_idx=out_idx, result_idx=out_idx,
...@@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]):
) )
elif execution_backend == "nvrtc": elif execution_backend == "nvrtc":
from tilelang.jit.adapter import NVRTCKernelAdapter from tilelang.jit.adapter import NVRTCKernelAdapter
adapter = NVRTCKernelAdapter( adapter = NVRTCKernelAdapter(
params=artifact.params, params=artifact.params,
result_idx=out_idx, result_idx=out_idx,
...@@ -315,16 +315,18 @@ class JITKernel(Generic[_P, _T]): ...@@ -315,16 +315,18 @@ class JITKernel(Generic[_P, _T]):
return adapter return adapter
def _create_adapter_from_database(self, def _create_adapter_from_database(
params: list[KernelParam], self,
result_idx: list[int] | int, params: list[KernelParam],
target: str | Target, result_idx: list[int] | int,
func_or_mod: PrimFunc | tvm.runtime.Module, target: str | Target,
host_kernel_source: str, func_or_mod: PrimFunc | tvm.runtime.Module,
device_kernel_source: str, host_kernel_source: str,
kernel_lib_path: str, device_kernel_source: str,
pass_configs: dict[str, Any] | None = None, kernel_lib_path: str,
compile_flags: list[str] | None = None) -> BaseKernelAdapter: pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
) -> BaseKernelAdapter:
target = self.target target = self.target
execution_backend = self.execution_backend execution_backend = self.execution_backend
...@@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]):
) )
elif execution_backend == "nvrtc": elif execution_backend == "nvrtc":
from tilelang.jit.adapter import NVRTCKernelAdapter from tilelang.jit.adapter import NVRTCKernelAdapter
adapter = NVRTCKernelAdapter.from_database( adapter = NVRTCKernelAdapter.from_database(
params=params, params=params,
result_idx=result_idx, result_idx=result_idx,
...@@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]):
""" """
return cls(func=tilelang_func, **kwargs) return cls(func=tilelang_func, **kwargs)
def get_profiler(self, def get_profiler(self, tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
""" """
Creates a profiler to benchmark the compiled runtime module. Creates a profiler to benchmark the compiled runtime module.
...@@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]):
Profiler Profiler
A Profiler instance for benchmarking the runtime module. A Profiler instance for benchmarking the runtime module.
""" """
return Profiler(self.params, self.out_idx, return Profiler(self.params, self.out_idx, tensor_supply_type).with_default_adapter(self.adapter)
tensor_supply_type).with_default_adapter(self.adapter)
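A hedged usage sketch (jit_kernel stands for an already-compiled JITKernel; the benchmarking entry points on the returned Profiler are not shown in this diff):

from tilelang.profiler import TensorSupplyType

profiler = jit_kernel.get_profiler(tensor_supply_type=TensorSupplyType.Auto)
# The profiler is bound to this kernel's adapter via with_default_adapter, so it can
# generate inputs of the requested supply type and benchmark the compiled module directly.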
def get_kernel_source(self, kernel_only: bool = True) -> str: def get_kernel_source(self, kernel_only: bool = True) -> str:
""" """
...@@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]): ...@@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]):
dir_path = os.path.dirname(kernel_path) dir_path = os.path.dirname(kernel_path)
if dir_path: if dir_path:
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok=True)
with open(kernel_path, 'w') as f: with open(kernel_path, "w") as f:
f.write(self.get_kernel_source()) f.write(self.get_kernel_source())
if host_path is not None: if host_path is not None:
dir_path = os.path.dirname(host_path) dir_path = os.path.dirname(host_path)
if dir_path: if dir_path:
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok=True)
with open(host_path, 'w') as f: with open(host_path, "w") as f:
f.write(self.get_host_source()) f.write(self.get_host_source())
except Exception as e: except Exception as e:
logger.error(f"Failed to export sources: {e}") logger.error(f"Failed to export sources: {e}")
# Backward compatibility alias (deprecated) # Backward compatibility alias (deprecated)
def print_source_code(self, def print_source_code(self, which: Literal["kernel", "host", "both"] = "kernel", file: str | None = None) -> None:
which: Literal["kernel", "host", "both"] = "kernel",
file: str | None = None) -> None:
""" """
Deprecated: use show_source() or export_sources() instead. Deprecated: use show_source() or export_sources() instead.
...@@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]): ...@@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]):
>>> # Old API (still works but deprecated) >>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu") >>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
""" """
logger.warning( logger.warning("print_source_code is deprecated; use show_source() or export_sources() instead.")
"print_source_code is deprecated; use show_source() or export_sources() instead.")
if file is not None: if file is not None:
# Historical behavior wrote only kernel source when file provided # Historical behavior wrote only kernel source when file provided
self.export_sources(kernel_path=file) self.export_sources(kernel_path=file)
else: else:
self.show_source(which=which) self.show_source(which=which)
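For reference, the non-deprecated equivalents the shim above forwards to (paths are illustrative):

jit_kernel.show_source(which="kernel")                    # or "host" / "both"
jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")   # old file= behavior: kernel source only
jit_kernel.export_sources(kernel_path="/tmp/kernel.cu", host_path="/tmp/host.cpp")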
def update_tuner_result(self, latency: float, config: dict[str, Any], def update_tuner_result(self, latency: float, config: dict[str, Any], ref_latency: float) -> JITKernel:
ref_latency: float) -> JITKernel:
""" """
Updates the tuning results for this kernel. Updates the tuning results for this kernel.
...@@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]):
verbose = self.verbose verbose = self.verbose
# Ensure target is set so nvcc picks correct arch via Target.current() # Ensure target is set so nvcc picks correct arch via Target.current()
with self.target: with self.target:
return tl_nvcc.get_ptx_from_source( return tl_nvcc.get_ptx_from_source(code, compile_flags=self.compile_flags, verbose=verbose)
code, compile_flags=self.compile_flags, verbose=verbose)
def show_ptx(self) -> None: def show_ptx(self) -> None:
""" """
...@@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]):
if verbose is None: if verbose is None:
verbose = self.verbose verbose = self.verbose
with self.target: with self.target:
return tl_nvcc.get_sass_from_source( return tl_nvcc.get_sass_from_source(code, compile_flags=self.compile_flags, verbose=verbose)
code, compile_flags=self.compile_flags, verbose=verbose)
def show_sass(self) -> None: def show_sass(self) -> None:
""" """
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
# from .parser import * # from .parser import *
...@@ -102,7 +103,10 @@ from .utils import index_to_coordinates # noqa: F401 ...@@ -102,7 +103,10 @@ from .utils import index_to_coordinates # noqa: F401
from .symbolics import dynamic, symbolic # noqa: F401 from .symbolics import dynamic, symbolic # noqa: F401
from .annotations import ( # noqa: F401 from .annotations import ( # noqa: F401
use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio, use_swizzle,
annotate_layout,
annotate_safe_value,
annotate_l2_hit_ratio,
) )
......
...@@ -13,8 +13,10 @@ Available allocation functions: ...@@ -13,8 +13,10 @@ Available allocation functions:
Each function takes shape and dtype parameters and returns a TVM buffer object Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope. with the appropriate memory scope.
""" """
from __future__ import annotations from __future__ import annotations
from typing import TypeVar, overload, Literal, Callable from typing import TypeVar, overload, Literal, Callable
# Python 3.9 compatibility for advanced typing features (PEP 646) # Python 3.9 compatibility for advanced typing features (PEP 646)
try: try:
from typing import TypeVarTuple, Unpack # type: ignore[attr-defined] from typing import TypeVarTuple, Unpack # type: ignore[attr-defined]
...@@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype ...@@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype
from .v2.builder import OutTensor from .v2.builder import OutTensor
from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
_Shapes = TypeVarTuple('_Shapes') _Shapes = TypeVarTuple("_Shapes")
_DType = TypeVar('_DType') _DType = TypeVar("_DType")
def alloc_shared(shape: tuple[Unpack[_Shapes]], def alloc_shared(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
dtype: _DType,
scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a shared memory buffer for inter-thread communication. """Allocate a shared memory buffer for inter-thread communication.
Args: Args:
...@@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]], ...@@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]],
return T.alloc_buffer(shape, dtype, scope=scope) return T.alloc_buffer(shape, dtype, scope=scope)
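A minimal kernel sketch using alloc_shared, modeled on the standard TileLang GEMM examples; the tile sizes, buffer names, and the surrounding T.Kernel/T.copy calls are illustrative assumptions rather than part of this diff:

import tilelang.language as T

@T.prim_func
def stage_tiles(A: T.Tensor((128, 128), "float16")):
    # Four thread blocks along the column dimension, 128 threads each.
    with T.Kernel(4, 1, threads=128) as (bx, by):
        # 128x32 tile staged in dynamic shared memory (default scope "shared.dyn").
        A_shared = T.alloc_shared((128, 32), "float16")
        T.copy(A[0, bx * 32], A_shared)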
def alloc_local(shape: tuple[Unpack[_Shapes]], def alloc_local(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
dtype: _DType,
scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a local memory buffer for thread-private storage. """Allocate a local memory buffer for thread-private storage.
Args: Args:
...@@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]], ...@@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]],
return T.alloc_buffer(shape, dtype, scope=scope) return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_fragment(shape: tuple[Unpack[_Shapes]], def alloc_fragment(
dtype: _DType, shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local.fragment"
scope="local.fragment") -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]: ) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a fragment memory buffer for specialized operations. """Allocate a fragment memory buffer for specialized operations.
Args: Args:
...@@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]], ...@@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]],
@overload @overload
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer: def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = "local.var") -> Buffer: ...
...
@overload @overload
def alloc_var(dtype: str, def alloc_var(dtype: str, scope: str = "local.var", *, init: PrimExpr | int | float | None = None) -> Buffer: ...
scope: str = 'local.var',
*,
init: PrimExpr | int | float | None = None) -> Buffer:
...
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
...@@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): ...@@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
raise TypeError("Scope must be provided as a string in alloc_var.") raise TypeError("Scope must be provided as a string in alloc_var.")
parsed_scope = parsed_scope_arg parsed_scope = parsed_scope_arg
elif len(args) > 2: elif len(args) > 2:
raise TypeError( raise TypeError(f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
if not isinstance(parsed_scope, str): if not isinstance(parsed_scope, str):
raise TypeError("Scope must be a string in alloc_var.") raise TypeError("Scope must be a string in alloc_var.")
...@@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"): ...@@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
@overload @overload
def empty(shape: tuple[Unpack[_Shapes]], def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
def empty(*shape: Unpack[_Shapes], def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
if len(shape) == 1 and isinstance(shape[0], (tuple, list)): if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
return OutTensor(shape[0], dtype) return OutTensor(shape[0], dtype)
elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
...@@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes], ...@@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes],
elif all([isinstance(x, (int, PrimExpr)) for x in shape]): elif all([isinstance(x, (int, PrimExpr)) for x in shape]):
return OutTensor(shape, dtype) return OutTensor(shape, dtype)
else: else:
raise RuntimeError(f'Invalid shape {shape}') raise RuntimeError(f"Invalid shape {shape}")
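A sketch of the shape spellings handled by the branches above; each call builds an OutTensor, and the dimensions are placeholders:

C1 = T.empty((1024, 1024), "float16")      # shape tuple with positional dtype
C2 = T.empty(1024, 1024, dtype="float16")  # unpacked dims with keyword dtype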
"""Annotation helpers exposed on the TileLang language surface.""" """Annotation helpers exposed on the TileLang language surface."""
from typing import Callable from typing import Callable
from tilelang.layout import Layout from tilelang.layout import Layout
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# This file is modified from the original version, # This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/). # which is part of the TVM project (https://tvm.apache.org/).
"""Package tvm.script.ir_builder.tir""" """Package tvm.script.ir_builder.tir"""
from .ir import * # noqa: F401 from .ir import * # noqa: F401
from .ir import boolean as bool # noqa: F401 from .ir import boolean as bool # noqa: F401
from .ir import buffer as Buffer # noqa: F401 from .ir import buffer as Buffer # noqa: F401
......