Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace Ruff format (#1417)

parent e84b24bc
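The change below is formatting-only: every hunk rewrites Yapf-style output (single quotes, arguments packed against the opening parenthesis) into the style produced by `ruff format` (double quotes, trailing commas, parameters exploded one per line or collapsed onto a single long line). A minimal before/after sketch of the kind of rewrite applied throughout; the function name and parameters are hypothetical, chosen to mirror the `par_compile` hunk further down:

# Yapf-era formatting (what the removed lines look like):
def par_compile_example(funcs,
                        out_idx=None,
                        execution_backend='auto',
                        ignore_error=False):
    ...

# ruff format output (what the added lines look like): double quotes,
# magic trailing comma, parameters indented one level instead of aligned.
def par_compile_example(
    funcs,
    out_idx=None,
    execution_backend="auto",
    ignore_error=False,
):
    ...

Judging by the long single-line rewrites in several hunks, the formatter appears to be configured with a line length well above Black's 88-character default.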
......@@ -3,6 +3,7 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
from __future__ import annotations
from dataclasses import dataclass
......@@ -39,17 +40,16 @@ from tqdm.auto import tqdm
logger = getLogger(__name__)
_P = ParamSpec('_P')
_KP = ParamSpec('_KP')
_T = TypeVar('_T')
_Ret = TypeVar('_Ret')
_P = ParamSpec("_P")
_KP = ParamSpec("_KP")
_T = TypeVar("_T")
_Ret = TypeVar("_Ret")
def compile(
func: PrimFunc[_KP, _T] = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
"torch"] = "auto",
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
......@@ -83,11 +83,9 @@ def compile(
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
if hasattr(func, 'out_idx_override'):
if hasattr(func, "out_idx_override"):
if func.out_idx_override is not None and out_idx is not None:
raise ValueError(
"Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors"
)
raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors")
out_idx = func.out_idx_override or out_idx
# This path is not a performance critical path, so we can afford to convert the target.
......@@ -96,6 +94,7 @@ def compile(
# Resolve execution backend (handles aliases, auto, validation per target)
requested_backend = execution_backend
from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
execution_backend = resolve_execution_backend(requested_backend, target)
if verbose:
allowed_now = allowed_backends_for_target(target, include_unavailable=False)
......@@ -119,17 +118,18 @@ def compile(
)
def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
def par_compile(
funcs: Iterable[PrimFunc[_KP, _T]],
out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
"torch"] = "auto",
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | str | None = None,
num_workers: int = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
ignore_error: bool = False,
) -> list[JITKernel[_KP, _T]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
......@@ -151,7 +151,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
Additional keyword arguments to pass to the Compiler PassContext.
Refer to `tilelang.transform.PassConfigKey` for supported options.
"""
with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor:
futures = []
future_map = {}
for i, func in enumerate(funcs):
......@@ -189,7 +189,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
@dataclass
class JITImpl(Generic[_P, _KP, _T, _Ret]):
'''
"""
Detailed Just-In-Time wrapper for TileLang programs.
This dataclass encapsulates the configuration and runtime helpers used by the
......@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
PrimFunc and the resulting set is compiled in parallel via the
module-level `par_compile` helper. Returns a list of JITKernel objects
in the same order as the provided configs.
'''
"""
out_idx: list[int] | int | None
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
......@@ -302,10 +302,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
return tir
def par_compile(self,
configs: Iterable[dict[str, Any] | tuple[str, Any]],
num_workers: int = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
def par_compile(
self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False
) -> list[JITKernel[_KP, _T]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
......@@ -328,7 +327,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
"""
configs = list(configs)
funcs = []
for cfg in tqdm(configs, desc='Elaborating'):
for cfg in tqdm(configs, desc="Elaborating"):
if isinstance(cfg, tuple):
funcs.append(self.get_tir(*cfg))
elif isinstance(cfg, dict):
......@@ -345,7 +344,8 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
num_workers=num_workers,
ignore_error=ignore_error)
ignore_error=ignore_error,
)
def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
func = self.get_tir(*args, **kwargs)
......@@ -362,25 +362,25 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
if self.debug_root_path:
if isinstance(self.func, PrimFunc):
func_name = self.func.attrs['global_symbol']
func_name = self.func.attrs["global_symbol"]
else:
func_name = getattr(self.func, '__name__', 'jit_kernel')
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
program_file = f'tilelang_jit_program_{func_name}.py'
func_name = getattr(self.func, "__name__", "jit_kernel")
kernel_file = f"tilelang_jit_kernel_{func_name}.c"
program_file = f"tilelang_jit_program_{func_name}.py"
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
with open(path.join(self.debug_root_path, kernel_file), "w") as f:
print(kernel_result.get_kernel_source(), file=f)
with open(path.join(self.debug_root_path, program_file), 'w') as f:
with open(path.join(self.debug_root_path, program_file), "w") as f:
print(func.script(), file=f)
return kernel_result
def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs):
if isinstance(self.func, PrimFuncCreater):
tune_params = kwargs.pop('__tune_params', {})
tune_params = kwargs.pop("__tune_params", {})
return self.func.func_annot.parse_key(*args, **kwargs, **tune_params)
else:
tune_params = kwargs.pop('__tune_params', {})
tune_params = kwargs.pop("__tune_params", {})
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
......@@ -389,34 +389,31 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs):
if isinstance(self.func, PrimFuncCreater):
tune_params = kwargs.pop('__tune_params', {})
tune_params = kwargs.pop("__tune_params", {})
return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
else:
raise NotImplementedError(
"convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")
raise NotImplementedError("convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
# Separate out the tuning parameters from the user's kwargs
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
return_compile_arguments = kwargs.pop("__return_compile_arguments", False)
if return_compile_arguments:
logger.warning(
"`__return_compile_arguments` is deprecated and will be removed in future versions."
)
logger.warning("`__return_compile_arguments` is deprecated and will be removed in future versions.")
compile_args = {
'out_idx': self.out_idx,
'execution_backend': self.execution_backend,
'target': self.target,
'target_host': self.target_host,
'verbose': self.verbose,
'pass_configs': self.pass_configs,
'compile_flags': self.compile_flags,
"out_idx": self.out_idx,
"execution_backend": self.execution_backend,
"target": self.target,
"target_host": self.target_host,
"verbose": self.verbose,
"pass_configs": self.pass_configs,
"compile_flags": self.compile_flags,
}
return compile_args
key = self.parse_cache_key(*args, **kwargs)
tune_params = kwargs.pop('__tune_params', {})
tune_params = kwargs.pop("__tune_params", {})
kernel = self._kernel_cache.get(key, None)
if kernel is None:
......@@ -434,8 +431,7 @@ ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvr
@overload
def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]:
...
def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ...
@overload
......@@ -448,9 +444,8 @@ def jit(
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]:
...
compile_flags: list[str] | str | None = None,
) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ...
def jit( # This is the new public interface
......@@ -463,7 +458,8 @@ def jit( # This is the new public interface
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None):
compile_flags: list[str] | str | None = None,
):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
......@@ -516,7 +512,8 @@ def jit( # This is the new public interface
compile_flags=compile_flags,
func_source=inspect.getsource(orig_func),
signature=inspect.signature(orig_func),
lazy_jit=False)
lazy_jit=False,
)
if func is not None:
return decorator(func)
......@@ -525,8 +522,7 @@ def jit( # This is the new public interface
@overload
def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]:
...
def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ...
@overload
......@@ -539,9 +535,8 @@ def lazy_jit(
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]:
...
compile_flags: list[str] | str | None = None,
) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ...
def lazy_jit(
......@@ -555,7 +550,6 @@ def lazy_jit(
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None,
):
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
......@@ -567,7 +561,8 @@ def lazy_jit(
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
compile_flags=compile_flags)
compile_flags=compile_flags,
)
def decorator(func: Callable[_P, _T]):
pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True)
......@@ -576,10 +571,7 @@ def lazy_jit(
# return compile(pf, **compile_args)
# else:
return JITImpl(
func=pf,
**compile_args,
func_source=inspect.getsource(pf.orig_func),
signature=inspect.signature(pf.orig_func),
lazy_jit=True)
func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True
)
return decorator(func) if func is not None else decorator
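For orientation, here is a minimal usage sketch of the decorator API reworked above. Only the `jit` entry point and `JITImpl.par_compile` shown in this diff are assumed (and the conventional top-level re-export as `tilelang.jit`); `build_matmul` and the concrete configs are hypothetical placeholders.

import tilelang

@tilelang.jit(out_idx=-1, target="auto", execution_backend="auto")
def make_kernel(M: int, N: int, block_M: int = 128, block_N: int = 128):
    # Build and return a TileLang PrimFunc for these shapes/tile sizes.
    # (Body elided; build_matmul is a hypothetical helper.)
    return build_matmul(M, N, block_M, block_N)

# The first call compiles and caches a JITKernel for this argument key;
# later calls with the same key hit _kernel_cache.
kernel = make_kernel(1024, 1024)

# Several configurations can also be elaborated and built in parallel.
kernels = make_kernel.par_compile(
    [{"M": 1024, "N": 1024}, {"M": 4096, "N": 4096}],
    num_workers=4,
)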
"""The profiler and convert to torch utils"""
from __future__ import annotations
from abc import ABC, abstractmethod
......@@ -8,7 +9,6 @@ import torch
class BaseKernelAdapter(ABC):
func: Callable | None = None
def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None:
......@@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC):
result_idx = []
elif isinstance(result_idx, int):
if result_idx > len(params) or result_idx < -len(params):
raise ValueError(
f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
)
raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
if result_idx < 0:
result_idx = len(params) + result_idx
result_idx = [result_idx]
elif isinstance(result_idx, list):
for i, idx in enumerate(result_idx):
if idx >= len(params) or idx < -len(params):
raise ValueError(
f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
)
raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
if idx < 0:
result_idx[i] = len(params) + idx
else:
......
"""The profiler and convert to torch utils"""
from __future__ import annotations
import torch
from ..base import BaseKernelAdapter
......@@ -41,7 +42,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes
param_shapes: list[list] | None = None # Cache for parameter shapes
def __init__(self,
def __init__(
self,
params: list[TensorType],
result_idx: list[int],
target: str,
......@@ -52,7 +54,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
device_kernel_source: str | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -109,7 +112,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self._post_init()
@classmethod
def from_database(cls,
def from_database(
cls,
params: list[TensorType],
result_idx: list[int],
target: str,
......@@ -119,7 +123,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
kernel_lib_path: str,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter):
if param in buffer_map:
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)):
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
......@@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
"""
ctypes_args = [
ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args
]
ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
......@@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
@property
def is_dynamic(self):
"""Indicates whether the kernel handles dynamic shapes."""
return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0)
return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0
def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel."""
......
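The `_forward_from_prebuild_lib` path above reduces to one marshalling rule: torch tensors are passed as raw device pointers, plain ints (dynamic shape values) pass through unchanged, and the CUDA stream handle is appended last. A standalone sketch of that rule:

import ctypes

def to_ctypes_args(args, stream: int = 0):
    # Mirrors the conversion used by the ctypes adapter above: tensors become
    # c_void_p device pointers, scalar ints are forwarded as-is.
    converted = [ctypes.c_void_p(a.data_ptr()) if not isinstance(a, int) else a for a in args]
    converted.append(ctypes.c_void_p(stream))
    return converted

# e.g. lib.call(*to_ctypes_args([A, B, n], stream=torch.cuda.current_stream().cuda_stream))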
"""The profiler and convert to torch utils"""
from __future__ import annotations
import ctypes
import logging
......@@ -70,7 +71,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
# Pass configs for the compiler
pass_configs: dict[str, Any] | None = None
def __init__(self,
def __init__(
self,
params: list[KernelParam],
result_idx: list[int],
target: str | Target,
......@@ -80,7 +82,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
device_kernel_source: str | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.lib.get_last_error.restype = ctypes.c_char_p
result = self.lib.init()
if result != 0:
error_msg = self.lib.get_last_error().decode('utf-8')
error_msg = self.lib.get_last_error().decode("utf-8")
error_msg += f"\n{self.lib_code}"
raise RuntimeError(f"Initialization failed: {error_msg}")
......@@ -145,7 +148,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
self._post_init()
@classmethod
def from_database(cls,
def from_database(
cls,
params: list[TensorType],
result_idx: list[int],
target: str,
......@@ -155,7 +159,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
kernel_lib_path: str,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.lib.get_last_error.restype = ctypes.c_char_p
result = adapter.lib.init()
if result != 0:
error_msg = adapter.lib.get_last_error().decode('utf-8')
error_msg = adapter.lib.get_last_error().decode("utf-8")
raise RuntimeError(f"Initialization failed: {error_msg}")
adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params,
adapter.lib)
adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, adapter.lib)
adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map)
adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
......@@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
if param in buffer_map:
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)):
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
......@@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
params = func.params
ptr_map = {}
for i, param in enumerate(params):
if param.dtype == 'handle':
if param.dtype == "handle":
ptr_map[i] = param.name
return ptr_map
def _process_static_buffer_infos(self) -> \
tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]],
dict[tir.Var, tuple[int, list[tuple[int, int]]]],
list[tuple[tir.Var]]]:
def _process_static_buffer_infos(
self,
) -> tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]], list[tuple[tir.Var]]]:
"""Extract information about static shapes from the TIR function.
Maps buffer variables to their corresponding static shapes.
......@@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
"""
ctypes_args = [
ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args
]
ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
......@@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
skip_tensor_validation: Whether to skip tensor attributes validation which
includes shape, dtype, device, etc.
"""
return self.cython_wrapper.forward([*args],
stream=stream,
skip_tensor_validation=skip_tensor_validation)
return self.cython_wrapper.forward([*args], stream=stream, skip_tensor_validation=skip_tensor_validation)
return lambda_forward
......
......@@ -55,6 +55,7 @@ class LibraryGenerator:
verbose = self.verbose
if is_cuda_target(target):
from tilelang.env import CUTLASS_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115
target_arch = get_target_arch(get_target_compute_version(target))
libpath = src.name.replace(".cu", ".so")
......@@ -65,15 +66,12 @@ class LibraryGenerator:
"TL_ENABLE_FAST_MATH",
"0.1.7",
)
enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH,
True)
enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True)
else:
enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)
ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL,
None)
verbose_ptxas_output = self.pass_configs.get(
PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None)
verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
command = [
get_nvcc_compiler(),
......@@ -102,6 +100,7 @@ class LibraryGenerator:
elif is_hip_target(target):
from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115
libpath = src.name.replace(".cpp", ".so")
rocm_path = find_rocm_path()
......@@ -119,6 +118,7 @@ class LibraryGenerator:
]
elif is_cpu_target(target):
from tilelang.contrib.cc import get_cplus_compiler
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115
libpath = src.name.replace(".cpp", ".so")
......@@ -134,9 +134,7 @@ class LibraryGenerator:
]
if self.compile_flags:
command += [
item for flag in self.compile_flags for item in flag.split() if item not in command
]
command += [item for flag in self.compile_flags for item in flag.split() if item not in command]
command += ["-o", libpath]
......@@ -151,8 +149,7 @@ class LibraryGenerator:
raise RuntimeError(f"Compile kernel failed because of {e}") from e
if ret.returncode != 0:
raise RuntimeError(f"Compilation Failed! {command}"
f"\n {self.lib_code}")
raise RuntimeError(f"Compilation Failed! {command}\n {self.lib_code}")
self.srcpath = src.name
self.libpath = libpath
......
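One behavior worth noting in the command construction above: each entry of `compile_flags` may hold several space-separated tokens, so the flags are split, flattened, and only tokens not already present in the command are appended. A small sketch of that merge, with a made-up command:

def merge_compile_flags(command: list[str], compile_flags: list[str]) -> list[str]:
    # Same comprehension as in LibraryGenerator above, applied to a copy of the command.
    command = list(command)
    command += [item for flag in compile_flags for item in flag.split() if item not in command]
    return command

print(merge_compile_flags(["nvcc", "-O3"], ["-O3 -lineinfo", "--use_fast_math"]))
# ['nvcc', '-O3', '-lineinfo', '--use_fast_math']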
......@@ -5,22 +5,22 @@ This module provides runtime compilation support using NVIDIA's NVRTC API.
import logging
__all__ = [
'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available',
'check_nvrtc_available'
]
__all__ = ["NVRTCKernelAdapter", "TLNVRTCSourceWrapper", "NVRTCLibraryGenerator", "is_nvrtc_available", "check_nvrtc_available"]
logger = logging.getLogger(__name__)
# Check if cuda-python is available
is_nvrtc_available = False
NVRTC_UNAVAILABLE_MESSAGE = ("cuda-python is not available, NVRTC backend cannot be used. "
NVRTC_UNAVAILABLE_MESSAGE = (
"cuda-python is not available, NVRTC backend cannot be used. "
"Please install cuda-python via `pip install cuda-python` "
"if you want to use the NVRTC backend.")
"if you want to use the NVRTC backend."
)
try:
import cuda.bindings.driver as cuda # noqa: F401
import cuda.bindings.nvrtc as nvrtc # noqa: F401
is_nvrtc_available = True
except ImportError as e:
logger.debug(f"cuda-python import failed: {e}")
......
......@@ -27,7 +27,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
pymodule = None
kernels = {}
def __init__(self,
def __init__(
self,
params: list[KernelParam],
result_idx: list[int],
target: str | Target,
......@@ -37,8 +38,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
device_kernel_source: str | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
check_nvrtc_available()
self.params = params
......@@ -92,7 +93,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self._post_init()
@classmethod
def from_database(cls,
def from_database(
cls,
params: list[KernelParam],
result_idx: list[int],
target: str,
......@@ -102,7 +104,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
kernel_lib_path: str,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -183,8 +186,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
return self.host_func
def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
"""
"""Low-level function to call the compiled CUDA kernel."""
return self.pymodule.call(self.kernels, *args, stream=stream)
def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
......
......@@ -13,6 +13,7 @@ Key responsibilities:
- Load compiled cubin and extract kernel handles
- Manage library lifecycle (load/unload)
"""
from __future__ import annotations
import importlib
import logging
......@@ -56,6 +57,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
culib: CUDA library handle (CUlibrary)
pymodule: Imported Python module containing call() function
"""
host_func: str | None = None
culib: cuda.CUlibrary | None = None
pymodule: ModuleType | None = None
......@@ -131,10 +133,10 @@ class NVRTCLibraryGenerator(LibraryGenerator):
ctx = cuda.cuCtxGetCurrent()[1]
if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
import torch
torch.cuda.synchronize()
result, self.culib = cuda.cuLibraryLoadFromFile(
bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
result, self.culib = cuda.cuLibraryLoadFromFile(bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
if result != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}")
......@@ -164,7 +166,8 @@ class NVRTCLibraryGenerator(LibraryGenerator):
target = self.target
verbose = self.verbose
if is_cuda_target(target):
from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
libpath = src.name.replace(".cu", ".cubin")
......@@ -195,13 +198,9 @@ class NVRTCLibraryGenerator(LibraryGenerator):
f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}",
]
if self.compile_flags:
options += [
item for flag in self.compile_flags for item in flag.split()
if item not in options
]
options += [item for flag in self.compile_flags for item in flag.split() if item not in options]
cubin_bytes = compile_cuda(
self.lib_code, target_format="cubin", options=options, verbose=verbose)
cubin_bytes = compile_cuda(self.lib_code, target_format="cubin", options=options, verbose=verbose)
with open(libpath, "wb") as f:
f.write(cubin_bytes)
......@@ -212,8 +211,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
self.libpath = libpath
self.pypath = src.name.replace(".cu", ".py")
if self.host_func is None:
raise RuntimeError(
"Host function is not set, please call update_host_func() first.")
raise RuntimeError("Host function is not set, please call update_host_func() first.")
with open(self.pypath, "w") as f:
f.write(self.host_func)
else:
......
......@@ -12,6 +12,7 @@ Key design:
- Dict-based deduplication ensures TMA descriptors are created only once
- Generates pure Python using cuda.bindings.driver for zero C++ dependency
"""
from __future__ import annotations
from typing import Any, ClassVar
......@@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit
from tilelang import tvm as tvm
from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper
from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr,
parse_function_call_args, parse_tma_descriptor_args)
from tilelang.jit.adapter.utils import match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args
PREDEF_HOST_FUNC_PY = """
from cuda.bindings.driver import (
......@@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
_generated_host_func: str | None = None
def __init__(self,
def __init__(
self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
pass_configs: dict[str, Any] | None = None,
):
"""Initialize NVRTC wrapper with compiled IR modules.
Args:
......@@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({
function_args.append(
{
"name": buffer.data.name,
"type": "ctypes.c_void_p",
})
}
)
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
if dyn_sym not in [arg["name"] for arg in function_args]:
......@@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
return (f"{name}.data_ptr()", arg_type)
return (name, arg_type)
call_args = parse_function_call_args(declaration, function_args, function_params,
desc_name_map, desc_name_var_map,
transform_nvrtc_arg)
call_args = parse_function_call_args(
declaration, function_args, function_params, desc_name_map, desc_name_var_map, transform_nvrtc_arg
)
for arg_name, arg_type in call_args:
if arg_type == "ctypes.c_void_p":
......@@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
break
# Store kernel info for second pass
kernel_info_list.append({
'function_name': function_name,
'block_info': block_info,
'grid_info': grid_info,
'dynamic_smem_buf': dynamic_smem_buf,
'call_args': call_args,
'device_index': device_index,
})
kernel_info_list.append(
{
"function_name": function_name,
"block_info": block_info,
"grid_info": grid_info,
"dynamic_smem_buf": dynamic_smem_buf,
"call_args": call_args,
"device_index": device_index,
}
)
# Generate TMA descriptor initialization code once for all kernels
kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
# Second pass: generate kernel launch code for each kernel
for kernel_info in kernel_info_list:
function_name = kernel_info['function_name']
block_info = kernel_info['block_info']
grid_info = kernel_info['grid_info']
dynamic_smem_buf = kernel_info['dynamic_smem_buf']
call_args = kernel_info['call_args']
device_index = kernel_info['device_index']
function_name = kernel_info["function_name"]
block_info = kernel_info["block_info"]
grid_info = kernel_info["grid_info"]
dynamic_smem_buf = kernel_info["dynamic_smem_buf"]
call_args = kernel_info["call_args"]
device_index = kernel_info["device_index"]
arg_names = ", ".join([arg[0] for arg in call_args])
arg_types = ", ".join([arg[1] for arg in call_args])
......@@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
kernel_launch_code += init_l2_persistent_map
# Generate kernel launch code
kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(function_name,
kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(
function_name,
self._pythonic_expr(grid_info[0]),
self._pythonic_expr(grid_info[1]),
self._pythonic_expr(grid_info[2]),
self._pythonic_expr(block_info[0]),
self._pythonic_expr(block_info[1]),
self._pythonic_expr(block_info[2]),
smem_str, arg_names, arg_types,
device_index)
smem_str,
arg_names,
arg_types,
device_index,
)
# Reset L2 persistent map after all kernel execution
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC_PY.format(
repr(list(function_informations.keys())), def_args, kernel_launch_code)
host_func = PREDEF_HOST_FUNC_PY.format(repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func
def generate_l2_persistent_map(self, function_name: str) -> str:
......@@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
if function_name not in self.l2_persistent_map:
return ""
init_l2_persistent_map = ""
for buffer_name, (hit_ratio,
size_in_bytes) in self.l2_persistent_map[function_name].items():
for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items():
# Get persisting_l2_cache_max_size
from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
try:
num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
except TypeError:
# as size_in_bytes may be a symbolic expression
num_bytes = persisting_l2_cache_max_size
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(
buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
return init_l2_persistent_map
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str],
desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
"""Generate Python code to initialize TMA descriptors.
TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects
......@@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
return tma_descriptor_init
# Parse TMA descriptor arguments using the common utility
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map,
desc_name_var_map, self._pythonic_expr)
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr)
# Generate Python code from parsed parameters
for params in parsed_params:
if not params.is_img2col:
tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address,
params.handle_name,
params.dtype,
params.tensor_rank,
params.global_address,
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)),
", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)),
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
else:
tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address,
params.handle_name,
params.dtype,
params.tensor_rank,
params.global_address,
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
", ".join(map(lambda x: f"cuuint32_t({x})",
params.element_strides)), ", ".join(params.lower_corner),
", ".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel,
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)),
", ".join(params.lower_corner),
", ".join(params.upper_corner),
params.smem_box_channel,
params.smem_box_pixel,
params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
return tma_descriptor_init
......@@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params
if isinstance(node, tvm.tir.Call):
if not (hasattr(node, "op") and
node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
return
args = node.args
if not args or args[0] != fn:
return
if len(args) < 1 + param_cnt:
raise AssertionError(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
function_params = args[1:1 + param_cnt]
raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters")
function_params = args[1 : 1 + param_cnt]
post_order_visit(self.host_func.body, visitor)
assert function_params is not None, "function_params should not be None"
......
from .metal import MetalKernelAdapter
__all__ = ['MetalKernelAdapter']
__all__ = ["MetalKernelAdapter"]
......@@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam
class MetalKernelAdapter(BaseKernelAdapter):
def __init__(
self,
params: list[KernelParam],
......@@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter):
):
self.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
func_name = func_or_mod.attrs['global_symbol']
func_name = func_or_mod.attrs["global_symbol"]
else:
func_name = func_or_mod.__name__
self.kernel_name = func_name + '_kernel'
self.kernel_name = func_name + "_kernel"
self.verbose = verbose
self.block_info = [1, 1, 1]
......@@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
for var, func in device_mod.functions.items():
assert var.name_hint == self.kernel_name
thread_extent = func.attrs['thread_extent']
thread_extent = func.attrs["thread_extent"]
for tag, extent in thread_extent.items():
if "threadIdx" in tag:
self.block_info["xyz".index(tag[-1])] = extent
......@@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
self.grid_info["xyz".index(tag[-1])] = extent
break
else:
raise AssertionError(f'no kernel with name {func_name}')
raise AssertionError(f"no kernel with name {func_name}")
# print(self.block_info, self.grid_info)
super().__init__(func_or_mod, result_idx=result_idx, params=params)
......@@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter):
_kernel = None
def _convert_torch_func(self) -> Callable:
if self._kernel is None:
_kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name)
_threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)]
@wraps(_kernel)
def launcher(*args: torch.Tensor):
return _kernel(
*args,
threads=_threads,
......
......@@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked,
the execution observes the same stream context as the active Torch code.
On non-CUDA builds, the stream/device fall back to 0/CPU semantics.
"""
from __future__ import annotations
from typing import Callable, Any
......@@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
- The stream pointer returned is a raw CUDA stream handle compatible with
TVM's device API; on CPU or when CUDA is unavailable, we return 0.
"""
# Class attributes to store compiled kernel information
target: str | Target = "cuda"
ir_module: tvm.IRModule | None = None
......@@ -51,7 +53,8 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None
# Stream/device functors are inherited from BaseKernelAdapter
def __init__(self,
def __init__(
self,
params: list[KernelParam],
result_idx: list[int],
target: str | Target,
......@@ -63,7 +66,8 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
device_kernel_source: str | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
if param in buffer_map:
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)):
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
......@@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
# Validate input count strictly
expected_inputs = len(self.params) - len(self.result_idx)
if len(inputs) != expected_inputs:
raise ValueError(
f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")
raise ValueError(f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")
# Resolve the device used for outputs. Prefer the first tensor input's device
# if available, otherwise use PyTorch's current device.
......@@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
for s in param_shapes[i]:
if isinstance(s, tir.Var):
for key in dynamic_symbolic_map:
if (str(s) == str(key)):
ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[
key]
if str(s) == str(key):
ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key]
if ref_id == 2:
shape.append(inputs[ref_tensor_idx])
elif ref_id == 0:
shape.append(
tensor_list[ref_tensor_idx].shape[ref_shape_idx])
shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
elif ref_id == 1:
shape.append(
tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
else: # Already converted to Python int during initialization
shape.append(s)
......@@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
out_device = current_device_functor()
if len(shape) == 0:
param_name = self.params[i].name if hasattr(self.params[i],
'name') else f'parameter_{i}'
param_name = self.params[i].name if hasattr(self.params[i], "name") else f"parameter_{i}"
raise ValueError(
f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. "
f"Expected shape: {shape}")
f"Expected shape: {shape}"
)
tensor = torch.empty(*shape, dtype=dtype, device=out_device)
else:
tensor = inputs[ins_idx]
......@@ -256,7 +254,8 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
return func
@classmethod
def from_database(cls,
def from_database(
cls,
params: list[TensorType],
result_idx: list[int],
target: str,
......@@ -266,7 +265,8 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
kernel_lib_path: str,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......
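Across the adapters in this diff, `dynamic_symbolic_map` records where each symbolic dimension gets its runtime value: a `(kind, tensor_index, dim_index)` triple where kind 0 reads the value from an input tensor's shape, kind 1 from its strides, and kind 2 (seen in the forward wrapper above) takes the scalar input itself. A compact sketch of that lookup, under those assumptions:

from tvm import tir  # assumption: TVM is importable, as in the adapters above

def resolve_dim(sym: tir.Var, dynamic_symbolic_map, inputs, tensors):
    # Sketch of how the forward wrapper resolves one symbolic output dimension.
    kind, t_idx, d_idx = dynamic_symbolic_map[sym]
    if kind == 0:                     # from an input tensor's shape
        return tensors[t_idx].shape[d_idx]
    if kind == 1:                     # from an input tensor's strides
        return tensors[t_idx].stride()[d_idx]
    return inputs[t_idx]              # kind == 2: a scalar argument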
......@@ -70,7 +70,6 @@ def get_annotated_mod(
target_host: str | Target | None = None,
model_type: Literal["device", "host", "all"] = "all",
) -> IRModule | tuple[IRModule, IRModule]:
# Validate model_type early
if model_type not in {"device", "host", "all"}:
raise ValueError(f"Invalid model type: {model_type}")
......@@ -95,21 +94,15 @@ def get_annotated_mod(
# Define dispatch dictionary for different model types
dispatch = {
"device":
lambda m: tir.transform.Filter(_is_device_call)(m),
"host":
lambda m: tir.transform.Filter(_is_host_call)(m),
"all":
lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)
(m)),
"device": lambda m: tir.transform.Filter(_is_device_call)(m),
"host": lambda m: tir.transform.Filter(_is_host_call)(m),
"all": lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)(m)),
}
return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.PrimExpr,
dtype_map: dict[str, str] | None = None,
ignore_cast: bool = False) -> str:
def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
......@@ -169,8 +162,22 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
elif isinstance(
node,
(tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT,
tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)):
(
tvm.tir.Mul,
tvm.tir.FloorDiv,
tvm.tir.Add,
tvm.tir.Sub,
tvm.tir.FloorMod,
tvm.tir.LT,
tvm.tir.LE,
tvm.tir.GT,
tvm.tir.GE,
tvm.tir.EQ,
tvm.tir.NE,
tvm.tir.And,
tvm.tir.Or,
),
):
op_map = {
tvm.tir.Mul: "*",
tvm.tir.FloorDiv: "/",
......@@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
return next(iter(node_to_result_map[expr]), "")
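`pythonic_expr` is used by the source wrappers in this diff to render launch extents and TMA parameters as expression strings. A small usage sketch; the exact textual rendering (e.g. of floordiv, which the operator map turns into `/`) is an assumption:

from tvm import tir
from tilelang.jit.adapter.utils import pythonic_expr  # import path as used elsewhere in this diff

n = tir.Var("n", "int32")
print(pythonic_expr((n + 127) // 128 * 128))
# e.g. "(n + 127) / 128 * 128"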
def maybe_desc_name(name: str,
matches: list[str],
i: int,
desc_name_map: dict[str, str] | None = None) -> bool:
def maybe_desc_name(name: str, matches: list[str], i: int, desc_name_map: dict[str, str] | None = None) -> bool:
"""
Check if a parameter name corresponds to a TMA descriptor.
......@@ -290,8 +294,7 @@ def parse_function_call_args(
else:
call_args.append(match)
if desc_name_var_map is not None and function_params is not None:
assert len(call_args) <= len(function_params), \
f"Too many arguments: {len(call_args)} > {len(function_params)}"
assert len(call_args) <= len(function_params), f"Too many arguments: {len(call_args)} > {len(function_params)}"
desc_name_var_map[match] = function_params[len(call_args) - 1]
return call_args
......@@ -300,12 +303,7 @@ def parse_function_call_args(
class TMADescriptorParams:
"""Parsed TMA descriptor parameters."""
def __init__(self,
handle_name: str,
dtype: str,
tensor_rank: int,
global_address: Any,
is_img2col: bool = False):
def __init__(self, handle_name: str, dtype: str, tensor_rank: int, global_address: Any, is_img2col: bool = False):
self.handle_name = handle_name
self.dtype = dtype
self.tensor_rank = tensor_rank
......@@ -355,22 +353,19 @@ def parse_tma_descriptor_args(
results = []
for handle_name, _ in desc_name_map.items():
assert handle_name in desc_name_var_map, \
f"Handle name {handle_name} not found in desc_name_var_map"
assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
desc_var = desc_name_var_map[handle_name]
assert desc_var in tma_descriptor_args, \
f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
assert desc_var in tma_descriptor_args, f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
args = tma_descriptor_args[desc_var]
# Skip __tvm_tensormap_create_tiled and second element (like CUDA version)
if len(args) < 3:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
raise ValueError(f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args
is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
is_img2col = tma_create_str.value == "__tvm_tensormap_create_im2col"
# Convert basic fields
dtype = pythonic_expr_func(dtype)
......@@ -386,60 +381,45 @@ def parse_tma_descriptor_args(
# Tiled mode
expected_args_len = 4 * tensor_rank + 4
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
raise ValueError(
f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
)
# Extract dimensions and strides
params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
params.global_stride = [
pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
]
params.box_dim = [
pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
]
params.element_strides = [
pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank]
]
params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]]
params.box_dim = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]]
params.element_strides = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank]]
# Extract remaining parameters
try:
interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank : 4 * tensor_rank + 4]
params.interleave = pythonic_expr_func(interleave)
params.swizzle = pythonic_expr_func(swizzle)
params.l2_promotion = pythonic_expr_func(l2_promotion)
params.oob_fill = pythonic_expr_func(oob_fill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
raise ValueError("Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)") from e
else:
# Im2col mode
expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
raise ValueError(
f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
)
# Extract dimensions and strides
params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
params.global_stride = [
pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
]
params.element_strides = [
pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
]
params.lower_corner = [
pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
]
params.upper_corner = [
pythonic_expr_func(i)
for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
]
params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]]
params.element_strides = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]]
params.lower_corner = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank - 2]]
params.upper_corner = [pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2 : 5 * tensor_rank - 4]]
# Extract remaining parameters
try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \
remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = remaining_args[
5 * tensor_rank - 4 : 5 * tensor_rank + 2
]
params.smem_box_pixel = pythonic_expr_func(smem_box_pixel)
params.smem_box_channel = pythonic_expr_func(smem_box_channel)
params.interleave = pythonic_expr_func(interleave)
......
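The slicing in `parse_tma_descriptor_args` above relies on a fixed layout for the tiled mode: `tensor_rank` entries each for global dims, global strides, box dims and element strides, followed by the four scalars interleave, swizzle, l2_promotion and oob_fill (hence the `4 * tensor_rank + 4` length check). A small sketch of that unpacking under the same layout assumption:

def split_tiled_tma_args(remaining_args: list, tensor_rank: int) -> dict:
    # Mirrors the tiled-mode slicing in parse_tma_descriptor_args above.
    r = tensor_rank
    if len(remaining_args) < 4 * r + 4:
        raise ValueError(f"expected at least {4 * r + 4} args, got {len(remaining_args)}")
    return {
        "global_dim": remaining_args[:r],
        "global_stride": remaining_args[r : 2 * r],
        "box_dim": remaining_args[2 * r : 3 * r],
        "element_strides": remaining_args[3 * r : 4 * r],
        "scalars": remaining_args[4 * r : 4 * r + 4],  # interleave, swizzle, l2_promotion, oob_fill
    }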
......@@ -4,9 +4,18 @@ from tilelang import tvm as tvm
from typing import Any
from tvm import IRModule
from tvm.target import Target
from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr,
parse_function_call_args, parse_tma_descriptor_args)
from .utils import (
is_metal_target,
match_declare_kernel,
match_declare_kernel_cpu,
is_cuda_target,
is_hip_target,
is_cpu_target,
get_annotated_mod,
pythonic_expr,
parse_function_call_args,
parse_tma_descriptor_args,
)
import re
import logging
import textwrap
......@@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """
class BaseWrapper(ABC):
@abstractmethod
def wrap(self, *args, **kwargs):
raise NotImplementedError
......@@ -163,13 +171,15 @@ class TLCUDASourceWrapper:
host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None
def __init__(self,
def __init__(
self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
pass_configs: dict[str, Any] | None = None,
):
self.mod = scheduled_ir_module
self.target = target
self.source = source
......@@ -211,15 +221,16 @@ class TLCUDASourceWrapper:
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({
function_args.append(
{
"name": buffer.data.name,
"type": self._lookup_type(buffer.dtype) + "* __restrict__",
})
}
)
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
if dyn_sym not in [arg["name"] for arg in function_args]:
......@@ -256,38 +267,40 @@ class TLCUDASourceWrapper:
# Identify the start of the function body to insert arguments
index = code.index("{", index)
block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
block_str = (
f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
)
grid_str = (
f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
)
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
if self.use_cooperative_groups[function_name]:
args_list = parse_function_call_args(declaration, function_args, function_params,
desc_name_map, desc_name_var_map)
assert len(function_params) == len(
args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map)
assert len(function_params) == len(args_list), (
f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
)
args_array = [f"(void*)&{arg}" for arg in args_list]
call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
kernel_launch_code += call_args
# Using cudaLaunchCooperativeKernel to launch the kernel
kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
function_name, grid_str, block_str, function_name + "_args", smem_str)
function_name, grid_str, block_str, function_name + "_args", smem_str
)
else:
args_list = parse_function_call_args(declaration, function_args, function_params,
desc_name_map, desc_name_var_map)
assert len(function_params) == len(
args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
args_list = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map)
assert len(function_params) == len(args_list), (
f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
)
call_args = ", ".join(args_list)
kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n"
kernel_launch_code += f'\tTILELANG_CHECK_LAST_ERROR("{function_name}");\n'
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map,
desc_name_var_map)
init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
kernel_launch_code = init_tma_descriptor_args + kernel_launch_code
# Wrap the kernel dispatch logic in an external C function
......@@ -298,46 +311,63 @@ class TLCUDASourceWrapper:
if function_name not in self.l2_persistent_map:
return ""
init_l2_persistent_map = ""
for buffer_name, (hit_ratio,
size_in_bytes) in self.l2_persistent_map[function_name].items():
for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items():
# get persisting_l2_cache_max_size
from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
try:
num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
except Exception:
# as size_in_bytes may be a symbolic expression
num_bytes = persisting_l2_cache_max_size
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
return init_l2_persistent_map
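# Worked example of the clamping above, with made-up sizes: the persisting window
# never exceeds the device's persisting-L2 maximum, and a symbolic size (which
# cannot be compared numerically) falls back to that maximum.
persisting_l2_cache_max_size = 33554432  # assume a 32 MiB device cap
for size_in_bytes in (8 * 1024 * 1024, 64 * 1024 * 1024):
    num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
    print(num_bytes)  # 8388608, then 33554432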
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str],
desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
# Parse TMA descriptor arguments using the common utility
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map,
desc_name_var_map, self._pythonic_expr)
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr)
# Generate C++ code from parsed parameters
for params in parsed_params:
if not params.is_img2col:
tma_descripter_init += TMA_DESC_INIT_FUNC.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address,
",".join(params.global_dim), ",".join(params.global_stride),
",".join(params.box_dim), ",".join(params.element_strides), params.interleave,
params.swizzle, params.l2_promotion, params.oob_fill)
params.handle_name,
params.dtype,
params.tensor_rank,
params.global_address,
",".join(params.global_dim),
",".join(params.global_stride),
",".join(params.box_dim),
",".join(params.element_strides),
params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
else:
tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address,
",".join(params.global_dim), ",".join(params.global_stride),
",".join(params.element_strides), ",".join(params.lower_corner),
",".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel,
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
params.handle_name,
params.dtype,
params.tensor_rank,
params.global_address,
",".join(params.global_dim),
",".join(params.global_stride),
",".join(params.element_strides),
",".join(params.lower_corner),
",".join(params.upper_corner),
params.smem_box_channel,
params.smem_box_pixel,
params.interleave,
params.swizzle,
params.l2_promotion,
params.oob_fill,
)
return tma_descripter_init
......@@ -347,9 +377,8 @@ class TLCUDASourceWrapper:
device_mod, host_mod = get_annotated_mod(self.mod, self.target)
self.device_mod = device_mod
self.host_mod = host_mod
assert (len(self.device_mod.functions)
>= 1), "Device module should have at least one function."
assert (len(self.host_mod.functions) == 1), "Only support one function in host module."
assert len(self.device_mod.functions) >= 1, "Device module should have at least one function."
assert len(self.host_mod.functions) == 1, "Only support one function in host module."
block_info_map = {}
grid_info_map = {}
......@@ -438,8 +467,7 @@ class TLCUDASourceWrapper:
for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
if dynamic_smem_buf is not None:
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(
function_name, dynamic_smem_buf)
call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf)
# Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
......@@ -466,17 +494,14 @@ class TLCUDASourceWrapper:
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params
if isinstance(node, tvm.tir.Call):
if not (hasattr(node, "op") and
node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
return
args = node.args
if not args or args[0] != fn:
return
if len(args) < 1 + param_cnt:
raise AssertionError(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
function_params = args[1:1 + param_cnt]
raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters")
function_params = args[1 : 1 + param_cnt]
post_order_visit(self.host_func.body, visitor)
assert function_params is not None, "function_params should not be None"
......@@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"uchar": "uint8_t",
}
def __init__(self,
def __init__(
self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
pass_configs: dict[str, Any] | None = None,
):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def get_init_func(self):
......@@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
if dynamic_smem_buf is not None:
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(
function_name, dynamic_smem_buf)
call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(function_name, dynamic_smem_buf)
# Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
......@@ -623,13 +649,15 @@ class TLCPUSourceWrapper:
host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None
def __init__(self,
def __init__(
self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
pass_configs: dict[str, Any] | None = None,
):
self.mod = scheduled_ir_module
self.target = target
self.source = source
......@@ -658,15 +686,16 @@ class TLCPUSourceWrapper:
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({
function_args.append(
{
"name": buffer.name,
"type": self._lookup_type(buffer.dtype) + "*",
})
}
)
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments
for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})
......@@ -686,7 +715,6 @@ class TLCPUSourceWrapper:
_call_str = """"""
for function_name, _ in function_informations.items():
# Find the location of the global kernel function in the code
index = match_declare_kernel_cpu(code, function_name + "(")
......@@ -706,8 +734,8 @@ class TLCPUSourceWrapper:
def parse_source_information(self):
with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
device_mod, host_mod = get_annotated_mod(self.mod, self.target)
assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
assert (len(host_mod.functions) == 1), "Only support one function in host module."
assert len(device_mod.functions) >= 1, "Device module should have at least one function."
assert len(host_mod.functions) == 1, "Only support one function in host module."
function_names = []
for g_var, _ in device_mod.functions.items():
......@@ -767,14 +795,15 @@ class TLCPUSourceWrapper:
class TLMetalSourceWrapper:
def __init__(self,
def __init__(
self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
pass_configs: dict[str, Any] | None = None,
):
self.mod = scheduled_ir_module
self.target = target
self.source = source
......@@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper):
"""
A wrapper class for the TileLang backend.
"""
device_mod: IRModule | None = None
host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None
......@@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper):
target=self.target,
device_mod=self.device_mod,
host_mod=self.host_mod,
pass_configs=self.pass_configs)
pass_configs=self.pass_configs,
)
return wrapper.lib_code
class TLPyWrapper(TLWrapper):
def __init__(self, target: Target):
super().__init__(target)
......@@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper):
# assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if is_cuda_target(self.target):
from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper
wrapper_class = TLNVRTCSourceWrapper
else:
raise ValueError(f"Unsupported target for NVRTC backend: {self.target}")
......@@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper):
target=self.target,
device_mod=self.device_mod,
host_mod=self.host_mod,
pass_configs=self.pass_configs)
pass_configs=self.pass_configs,
)
return wrapper.host_func, wrapper.function_names
......@@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
# Drop NVRTC if not importable
try:
from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy
if not is_nvrtc_available and "nvrtc" in allowed:
allowed = [b for b in allowed if b != "nvrtc"]
except Exception:
......@@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
if req not in allowed_all:
raise ValueError(
f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. "
f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.")
f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'."
)
# Promote to availability-aware set for nicer errors (e.g., nvrtc not installed)
if req not in allowed_avail:
raise ValueError(
f"Execution backend '{requested}' requires extra dependencies and is not available now. "
f"Try one of: {_format_options(allowed_avail)}.")
f"Try one of: {_format_options(allowed_avail)}."
)
return req
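# Hedged usage sketch of the resolver above; the "cuda" target is illustrative and
# the resolved name depends on which backends are installed.
from tvm.target import Target
from tilelang.jit.execution_backend import resolve_execution_backend

backend = resolve_execution_backend("auto", Target("cuda"))
print(backend)  # some allowed-and-available backend for CUDA
# An unknown or target-invalid name raises ValueError listing the allowed options
# and suggesting execution_backend="auto".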
from __future__ import annotations
from typing import Any, Callable, Generic, Literal, TypeVar
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
......@@ -14,8 +15,7 @@ import tilelang
from tilelang import tvm
from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
TVMFFIKernelAdapter, MetalKernelAdapter)
from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter
from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc
......@@ -24,8 +24,8 @@ import os
logger = logging.getLogger(__name__)
_P = ParamSpec('_P')
_T = TypeVar('_T')
_P = ParamSpec("_P")
_T = TypeVar("_T")
class JITKernel(Generic[_P, _T]):
......@@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]):
torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function.
"""
prim_func: PrimFunc = None
artifact: CompiledArtifact = None
adapter: BaseKernelAdapter = None
......@@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]):
if execution_backend == "cython":
from tilelang.contrib.cc import get_cplus_compiler
assert (
get_cplus_compiler() is not None
), "Cython backend requires a C++ compiler, please install or use other backends."
assert get_cplus_compiler() is not None, "Cython backend requires a C++ compiler, please install or use other backends."
if from_database:
return
......@@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]):
"""
return self.torch_function(*args, **kwds)
def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
out_idx: list[int]) -> BaseKernelAdapter:
def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int]) -> BaseKernelAdapter:
"""
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
......@@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]):
target=target,
target_host=target_host,
enable_host_codegen=enable_host_codegen,
enable_device_compile=enable_device_compile)
enable_device_compile=enable_device_compile,
)
self.artifact = artifact
......@@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]):
if execution_backend == "tvm_ffi":
# Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
# But we need to ensure that the runtime is enabled and the runtime module is not None.
assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module."
assert artifact.rt_mod is not None, "tvm_ffi backend requires a runtime module."
adapter = TVMFFIKernelAdapter(
params=artifact.params,
result_idx=out_idx,
......@@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]):
)
elif execution_backend == "nvrtc":
from tilelang.jit.adapter import NVRTCKernelAdapter
adapter = NVRTCKernelAdapter(
params=artifact.params,
result_idx=out_idx,
......@@ -315,7 +315,8 @@ class JITKernel(Generic[_P, _T]):
return adapter
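# Hedged sketch of selecting the adapter path explicitly at compile time:
# "tvm_ffi" requires the runtime module produced above, while "nvrtc" is only
# imported when requested. `my_prim_func` is a placeholder PrimFunc.
import tilelang
kernel = tilelang.compile(my_prim_func, out_idx=[1], execution_backend="tvm_ffi")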
def _create_adapter_from_database(self,
def _create_adapter_from_database(
self,
params: list[KernelParam],
result_idx: list[int] | int,
target: str | Target,
......@@ -324,7 +325,8 @@ class JITKernel(Generic[_P, _T]):
device_kernel_source: str,
kernel_lib_path: str,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None) -> BaseKernelAdapter:
compile_flags: list[str] | None = None,
) -> BaseKernelAdapter:
target = self.target
execution_backend = self.execution_backend
......@@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]):
)
elif execution_backend == "nvrtc":
from tilelang.jit.adapter import NVRTCKernelAdapter
adapter = NVRTCKernelAdapter.from_database(
params=params,
result_idx=result_idx,
......@@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]):
"""
return cls(func=tilelang_func, **kwargs)
def get_profiler(self,
tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
def get_profiler(self, tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
"""
Creates a profiler to benchmark the compiled runtime module.
......@@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]):
Profiler
A Profiler instance for benchmarking the runtime module.
"""
return Profiler(self.params, self.out_idx,
tensor_supply_type).with_default_adapter(self.adapter)
return Profiler(self.params, self.out_idx, tensor_supply_type).with_default_adapter(self.adapter)
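# Hedged usage sketch, assuming `my_prim_func` is an existing TileLang PrimFunc:
# compile it, then benchmark it through the profiler bound to the kernel's adapter.
import tilelang
kernel = tilelang.compile(my_prim_func, out_idx=[1])
profiler = kernel.get_profiler()  # TensorSupplyType.Auto generates the inputs
latency_ms = profiler.do_bench()  # do_bench is assumed to be the benchmark entry point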
def get_kernel_source(self, kernel_only: bool = True) -> str:
"""
......@@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]):
dir_path = os.path.dirname(kernel_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(kernel_path, 'w') as f:
with open(kernel_path, "w") as f:
f.write(self.get_kernel_source())
if host_path is not None:
dir_path = os.path.dirname(host_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(host_path, 'w') as f:
with open(host_path, "w") as f:
f.write(self.get_host_source())
except Exception as e:
logger.error(f"Failed to export sources: {e}")
# Backward compatibility alias (deprecated)
def print_source_code(self,
which: Literal["kernel", "host", "both"] = "kernel",
file: str | None = None) -> None:
def print_source_code(self, which: Literal["kernel", "host", "both"] = "kernel", file: str | None = None) -> None:
"""
Deprecated: use show_source() or export_sources() instead.
......@@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]):
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
"""
logger.warning(
"print_source_code is deprecated; use show_source() or export_sources() instead.")
logger.warning("print_source_code is deprecated; use show_source() or export_sources() instead.")
if file is not None:
# Historical behavior wrote only kernel source when file provided
self.export_sources(kernel_path=file)
else:
self.show_source(which=which)
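# Hedged migration sketch for the deprecation above; `jit_kernel` stands for any
# compiled JITKernel instance and the paths are illustrative.
jit_kernel.show_source(which="both")  # print kernel and host source
jit_kernel.export_sources(kernel_path="/tmp/kernel.cu", host_path="/tmp/host.cpp")
jit_kernel.print_source_code(file="/tmp/kernel.cu")  # old API: still works, logs a warning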
def update_tuner_result(self, latency: float, config: dict[str, Any],
ref_latency: float) -> JITKernel:
def update_tuner_result(self, latency: float, config: dict[str, Any], ref_latency: float) -> JITKernel:
"""
Updates the tuning results for this kernel.
......@@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]):
verbose = self.verbose
# Ensure target is set so nvcc picks correct arch via Target.current()
with self.target:
return tl_nvcc.get_ptx_from_source(
code, compile_flags=self.compile_flags, verbose=verbose)
return tl_nvcc.get_ptx_from_source(code, compile_flags=self.compile_flags, verbose=verbose)
def show_ptx(self) -> None:
"""
......@@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]):
if verbose is None:
verbose = self.verbose
with self.target:
return tl_nvcc.get_sass_from_source(
code, compile_flags=self.compile_flags, verbose=verbose)
return tl_nvcc.get_sass_from_source(code, compile_flags=self.compile_flags, verbose=verbose)
def show_sass(self) -> None:
"""
......
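# Hedged inspection sketch: both helpers above require the CUDA toolchain and reuse
# the kernel's compile_flags; `jit_kernel` is any compiled CUDA JITKernel.
jit_kernel.show_ptx()   # prints PTX generated from the kernel source
jit_kernel.show_sass()  # prints SASS for the same source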
"""The language interface for tl programs."""
from __future__ import annotations
# from .parser import *
......@@ -102,7 +103,10 @@ from .utils import index_to_coordinates # noqa: F401
from .symbolics import dynamic, symbolic # noqa: F401
from .annotations import ( # noqa: F401
use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio,
use_swizzle,
annotate_layout,
annotate_safe_value,
annotate_l2_hit_ratio,
)
......
......@@ -13,8 +13,10 @@ Available allocation functions:
Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""
from __future__ import annotations
from typing import TypeVar, overload, Literal, Callable
# Python 3.9 compatibility for advanced typing features (PEP 646)
try:
from typing import TypeVarTuple, Unpack # type: ignore[attr-defined]
......@@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype
from .v2.builder import OutTensor
from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
_Shapes = TypeVarTuple('_Shapes')
_DType = TypeVar('_DType')
_Shapes = TypeVarTuple("_Shapes")
_DType = TypeVar("_DType")
def alloc_shared(shape: tuple[Unpack[_Shapes]],
dtype: _DType,
scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
def alloc_shared(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a shared memory buffer for inter-thread communication.
Args:
......@@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]],
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_local(shape: tuple[Unpack[_Shapes]],
dtype: _DType,
scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
def alloc_local(shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a local memory buffer for thread-private storage.
Args:
......@@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]],
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_fragment(shape: tuple[Unpack[_Shapes]],
dtype: _DType,
scope="local.fragment") -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
def alloc_fragment(
shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local.fragment"
) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a fragment memory buffer for specialized operations.
Args:
......@@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]],
@overload
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer:
...
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = "local.var") -> Buffer: ...
@overload
def alloc_var(dtype: str,
scope: str = 'local.var',
*,
init: PrimExpr | int | float | None = None) -> Buffer:
...
def alloc_var(dtype: str, scope: str = "local.var", *, init: PrimExpr | int | float | None = None) -> Buffer: ...
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
......@@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
raise TypeError("Scope must be provided as a string in alloc_var.")
parsed_scope = parsed_scope_arg
elif len(args) > 2:
raise TypeError(
f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
raise TypeError(f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
if not isinstance(parsed_scope, str):
raise TypeError("Scope must be a string in alloc_var.")
......@@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
@overload
def empty(shape: tuple[Unpack[_Shapes]],
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
def empty(*shape: Unpack[_Shapes],
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
return OutTensor(shape[0], dtype)
elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
......@@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes],
elif all([isinstance(x, (int, PrimExpr)) for x in shape]):
return OutTensor(shape, dtype)
else:
raise RuntimeError(f'Invalid shape {shape}')
raise RuntimeError(f"Invalid shape {shape}")
"""Annotation helpers exposed on the TileLang language surface."""
from typing import Callable
from tilelang.layout import Layout
......
......@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""Package tvm.script.ir_builder.tir"""
from .ir import * # noqa: F401
from .ir import boolean as bool # noqa: F401
from .ir import buffer as Buffer # noqa: F401
......