Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm_ffi
from tvm.target import Target
from tilelang import _ffi_api
@tvm.ffi.register_object("tl.Fill")
@tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
...
@tvm.ffi.register_object("tl.AtomicAdd")
@tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
...
@tvm.ffi.register_object("tl.Copy")
@tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
...
@tvm.ffi.register_object("tl.Conv2DIm2Col")
@tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
...
@tvm.ffi.register_object("tl.GemmWarpPolicy")
@tvm_ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
policy_type: int
m_warp: int
......@@ -39,41 +39,41 @@ class GemmWarpPolicy(Node, Scriptable):
return self.m_warp, self.n_warp
@tvm.ffi.register_object("tl.Gemm")
@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
...
@tvm.ffi.register_object("tl.GemmSP")
@tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
...
@tvm.ffi.register_object("tl.FinalizeReducerOp")
@tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
...
@tvm.ffi.register_object("tl.ParallelOp")
@tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
...
@tvm.ffi.register_object("tl.ReduceOp")
@tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
...
@tvm.ffi.register_object("tl.CumSumOp")
@tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
...
@tvm.ffi.register_object("tl.RegionOp")
@tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
...
@tvm.ffi.register_object("tl.ReduceType")
@tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
...
......@@ -5,15 +5,25 @@ kernel adapter using TVM.
"""
from __future__ import annotations
from dataclasses import dataclass
import inspect
from typing import (
Any,
Callable,
Generic,
TypeVar,
overload,
Literal,
)
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
from tilelang.jit.kernel import JITKernel
......@@ -21,14 +31,20 @@ from tilelang.utils.target import determine_target
from tilelang.cache import cached
from os import path, makedirs
from logging import getLogger
import functools
from tilelang.jit.param import Kernel
import concurrent.futures
from tqdm.auto import tqdm
logger = getLogger(__name__)
_P = ParamSpec('_P')
_KP = ParamSpec('_KP')
_T = TypeVar('_T')
def compile(
func: PrimFunc[_KP, _T] = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: str | Target = "auto",
......@@ -36,7 +52,7 @@ def compile(
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | str | None = None,
) -> JITKernel[_KP, _T]:
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
Parameters
......@@ -79,159 +95,208 @@ def compile(
)
class _JitImplementation:
def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | str | None = None,
num_workers: int | None = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
----------
funcs : Iterable[tvm.tir.PrimFunc]
The TileLang TIR functions to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
Execution backend to use for kernel execution (default: "cython").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation (default: None).
verbose : bool, optional
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Refer to `tilelang.transform.PassConfigKey` for supported options.
compile_flags : Union[List[str], str], optional
Additional compilation flags to pass to the compiler (default: None).
num_workers : int, optional
Number of worker threads; None lets ThreadPoolExecutor decide (default: None).
ignore_error : bool, optional
If True, log compilation errors and store None for the failed entries
instead of raising (default: False).
"""
with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
futures = []
future_map = {}
for i, func in enumerate(funcs):
future = executor.submit(
compile,
func=func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
future_map[future] = i
futures.append(future)
results = [None] * len(futures)
for future in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="Parallel Compiling",
):
idx = future_map[future]
if ignore_error:
try:
results[idx] = future.result()
except Exception as e:
logger.warning(f"Error compiling function at index {idx}: {e}")
results[idx] = None
else:
results[idx] = future.result()
return results
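# Usage sketch (hedged): `make_kernel` below is a hypothetical factory that
# returns TileLang PrimFuncs for different tile shapes.
#
#   funcs = [make_kernel(block_m, block_n) for block_m, block_n in [(64, 64), (128, 128)]]
#   kernels = par_compile(funcs, out_idx=-1, num_workers=4, ignore_error=True)
#   # kernels[i] is a JITKernel, or None if that entry failed and ignore_error=True.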
@dataclass
class JITImpl(Generic[_P, _KP, _T]):
func: Callable[_P, _T] | PrimFunc[_KP, _T]
out_idx: list[int] | int | None
execution_backend: Literal["dlpack", "ctypes", "cython"]
target: str | Target
target_host: str | Target
verbose: bool
pass_configs: dict[str, Any] | None
debug_root_path: str | None
compile_flags: list[str] | str | None
func_source: str
signature: inspect.Signature
def __init__(self,
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None):
"""
Initializes the JIT compiler decorator.
Parameters
----------
out_idx : Any, optional
Index(es) of the output tensors to return from the compiled kernel
(default: None, meaning all outputs are returned or determined by the kernel itself).
target : Union[str, Target], optional
Compilation target for TVM. Can be a string (e.g., "cuda", "llvm")
or a TVM Target object. If "auto", the target is determined automatically
(default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation, similar to `target` (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
The backend used for kernel execution and argument passing.
"dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks.
"ctypes" uses standard C types. "cython" uses Cython for potentially faster execution.
(default: "cython").
verbose : bool, optional
If True, enables verbose logging during compilation (default: False).
pass_configs : Optional[Dict[str, Any]], optional
A dictionary of configurations for TVM's pass context. These can fine-tune
the compilation process. Examples include "tir.disable_vectorize"
(default: None).
debug_root_path : Optional[str], optional
If provided, the compiled kernel's source code will be saved to a file
in this directory. This is useful for debugging the generated code.
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
compile_flags : Optional[Union[List[str], str]], optional
Additional compilation flags to pass to the compiler.
If None, no additional compilation flags are passed (default: None).
"""
self.out_idx = out_idx
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
self.verbose = verbose
self.pass_configs = pass_configs
self.compile_flags = compile_flags
# Corrected debug_root_path handling
self.debug_root_path = debug_root_path
def __post_init__(self):
if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
try:
base_path = path.dirname(path.dirname(path.dirname(__file__)))
self.debug_root_path = path.join(base_path, self.debug_root_path)
except NameError:
self.debug_root_path = path.abspath(self.debug_root_path)
self._kernel_cache: dict[tuple, Kernel] = {}
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]:
...
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
...
# Actual implementation of __call__
def __call__(
self,
func: Callable[_P, _RProg] # func is Union[Callable[_P, _RProg], PrimFunc] in original
) -> Callable[_P, Any]:
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
# Separate out the tuning parameters from the user's kwargs
tune_params = kwargs.pop('__tune_params', {})
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
if return_compile_arguments:
compile_args = {
'out_idx': self.out_idx,
'execution_backend': self.execution_backend,
'target': self.target,
'target_host': self.target_host,
'verbose': self.verbose,
'pass_configs': self.pass_configs,
'compile_flags': self.compile_flags,
}
return compile_args
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
if key not in self._kernel_cache:
# Ensure 'func' (the original user function) is used correctly
program_result_source = func
if isinstance(program_result_source, PrimFunc):
program_result = program_result_source
elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs, **tune_params)
else:
raise ValueError(f"Invalid function type: {type(program_result_source)}")
kernel_result = compile(
program_result,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
)
if self.debug_root_path:
func_name = getattr(func, '__name__', 'jit_kernel') # Use func for name
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
program_file = f'tilelang_jit_program_{func_name}.py'
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
with open(path.join(self.debug_root_path, program_file), 'w') as f:
print(program_result.script(), file=f)
self._kernel_cache[key] = kernel_result
return self._kernel_cache[key]
return wrapper
def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]:
program_result_source = self.func
if isinstance(program_result_source, PrimFunc):
program_result = program_result_source
elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs)
else:
raise ValueError(f"Invalid function type: {type(program_result_source)}")
return program_result
def par_compile(self,
configs: Iterable[dict[str, Any] | tuple[str, Any]],
num_workers: int | None = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
configs = list(configs)
funcs = []
for cfg in tqdm(configs, desc='Elaborating'):
if isinstance(cfg, tuple):
funcs.append(self.get_tir(*cfg))
elif isinstance(cfg, dict):
funcs.append(self.get_tir(**cfg))
else:
raise ValueError(f"Invalid config type: {type(cfg)}, expected tuple or dict.")
return par_compile(
funcs,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
num_workers=num_workers,
ignore_error=ignore_error)
def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
func = self.get_tir(*args, **kwargs)
kernel_result = compile(
func,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
)
if self.debug_root_path:
if isinstance(self.func, PrimFunc):
func_name = self.func.attrs['global_symbol']
else:
func_name = getattr(self.func, '__name__', 'jit_kernel')
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
program_file = f'tilelang_jit_program_{func_name}.py'
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
with open(path.join(self.debug_root_path, program_file), 'w') as f:
print(func.script(), file=f)
return kernel_result
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
# Separate out the tuning parameters from the user's kwargs
tune_params = kwargs.pop('__tune_params', {})
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
if return_compile_arguments:
compile_args = {
'out_idx': self.out_idx,
'execution_backend': self.execution_backend,
'target': self.target,
'target_host': self.target_host,
'verbose': self.verbose,
'pass_configs': self.pass_configs,
'compile_flags': self.compile_flags,
}
return compile_args
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
if key not in self._kernel_cache:
self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
return self._kernel_cache[key]
@overload
def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T]:
...
@overload
def jit(
*, # Indicates subsequent arguments are keyword-only
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T]]:
...
def jit( # This is the new public interface
func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
out_idx: Any = None,
target: str | Target = "auto",
......@@ -275,32 +340,26 @@ def jit( # This is the new public interface
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _KP, _T]:
if isinstance(func, PrimFunc):
orig_func = func.orig_func
else:
orig_func = func
return JITImpl(
func,
out_idx=out_idx,
target=target,
target_host=target_host,
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
compile_flags=compile_flags,
func_source=inspect.getsource(orig_func),
signature=inspect.signature(orig_func),
)
if func is not None:
return decorator(func)
else:
return decorator
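# Usage sketch for the decorator form (the kernel body is hypothetical and
# elided; only the decorator mechanics are illustrated):
#
#   @tilelang.jit(out_idx=-1)
#   def matmul(M, N, K, block_M=128, block_N=128):
#       ...  # build and return a TileLang PrimFunc
#
#   kernel = matmul(1024, 1024, 1024)        # elaborates + compiles, returns a JITKernel
#   kernel_again = matmul(1024, 1024, 1024)  # same args: served from _kernel_cache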
......@@ -27,7 +27,11 @@ class MetalKernelAdapter(BaseKernelAdapter):
# compile_flags: Optional[List[str]] = None
):
self.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
func_name = func_or_mod.attrs['global_symbol']
else:
func_name = func_or_mod.__name__
self.kernel_name = func_name + '_kernel'
self.verbose = verbose
self.block_info = [1, 1, 1]
......@@ -43,7 +47,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
self.grid_info["xyz".index(tag[-1])] = extent
break
else:
raise AssertionError(f'no kernel with name {func_name}')
# print(self.block_info, self.grid_info)
super().__init__(func_or_mod, result_idx=result_idx, params=params)
......
......@@ -257,6 +257,12 @@ class TLCUDASourceWrapper:
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
return pythonic_expr(expr, self._TYPE_MAP)
def _lookup_type(self, dtype: str | Any) -> str:
key = dtype if isinstance(dtype, str) else str(dtype)
result = self._TYPE_MAP.get(key)
assert result is not None, f"Unsupported dtype {dtype}"
return result
def is_tma_descriptor_arg(self, arg_name: str) -> bool:
return arg_name in self.prim_func.buffer_map
......@@ -274,10 +280,10 @@ class TLCUDASourceWrapper:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.data.name,
"type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
"type": self._lookup_type(buffer.dtype) + "* __restrict__",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
......@@ -717,6 +723,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"float16": "ctypes.c_uint16",
"bfloat16": "ctypes.c_uint16",
"float8_e4m3": "ctypes.c_uint8",
"float8_e4m3fn": "ctypes.c_uint8",
"float8_e5m2": "ctypes.c_uint8",
"float64": "ctypes.c_double",
"int64": "ctypes.c_int64",
......@@ -753,7 +760,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"type": "ctypes.c_void_p",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
......@@ -923,6 +930,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"float16": "half_t",
"bfloat16": "bfloat16_t",
"float8_e4m3": "fp8_e4_t",
"float8_e4m3fn": "fp8_e4_t",
"float8_e5m2": "fp8_e5_t",
"float8_e4m3fnuz": "fp8_e4_t",
"e4m3fnuz_float8": "fp8_e4_t",
......@@ -969,16 +977,19 @@ class TLCPUSourceWrapper:
"float32": "float",
"float16": "half",
"int32": "int32_t",
"int8": "int8_t",
"uint8": "uint8_t",
"int16": "int16_t",
"uint16": "uint16_t",
"int64": "int64_t",
"uint64": "uint64_t",
"float64": "double",
"bool": "bool",
"uchar": "uchar",
}
# Use common init with error buffer and get_last_error for CPU backend as well
INIT_FUNC = PREDEF_INIT_FUNC.format("")
CALL_PREFIX = textwrap.dedent("""
#ifdef __cplusplus
......@@ -1014,6 +1025,12 @@ class TLCPUSourceWrapper:
self.libpath: str | None = None
self.lib_code: str | None = self.update_lib_code(source)
def _lookup_type(self, dtype: str | Any) -> str:
key = dtype if isinstance(dtype, str) else str(dtype)
result = self._TYPE_MAP.get(key)
assert result is not None, f"Unsupported dtype {dtype}"
return result
def create_call_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
......@@ -1025,10 +1042,10 @@ class TLCPUSourceWrapper:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "*",
"type": self._lookup_type(buffer.dtype) + "*",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
......@@ -1093,8 +1110,8 @@ class TLCPUSourceWrapper:
return dynamic_symbolic_set
def get_cpu_init_func(self):
# Provide init() and get_last_error() for CPU backend
return self.INIT_FUNC
def update_lib_code(self, code: str):
# Update the library code with the given code string
......
from __future__ import annotations
from typing import Any, Callable, Generic, Literal, TypeVar
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tilelang.jit.adapter.utils import is_metal_target, is_cuda_target
from tvm.target import Target
from tvm.tir import PrimFunc
......@@ -13,12 +18,17 @@ from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, Cython
NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc
import logging
import os
logger = logging.getLogger(__name__)
_P = ParamSpec('_P')
_T = TypeVar('_T')
class JITKernel(Generic[_P, _T]):
"""
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
......@@ -170,7 +180,7 @@ class JITKernel:
instance.torch_function = instance.adapter.func
return instance
def __call__(self, *args: _P.args, **kwds: _P.kwargs) -> _T:
"""
Invokes the compiled function with the given arguments.
......@@ -404,6 +414,110 @@ class JITKernel:
def run_once(self, func: Callable | None = None) -> None:
return self.get_profiler().run_once(func)
def show_source(self, which: Literal["kernel", "host", "both"] = "kernel") -> None:
"""
Print generated source code to stdout.
Parameters
----------
which : Literal["kernel", "host", "both"], optional
Select which source to print. Defaults to "kernel".
Examples
--------
>>> jit_kernel.show_source() # print kernel source
>>> jit_kernel.show_source("host") # print host source
>>> jit_kernel.show_source("both") # print both sources
"""
try:
if which == "kernel":
src = self.get_kernel_source()
print(src)
elif which == "host":
src = self.get_host_source()
# Host is generally C/C++
print(src)
elif which == "both":
print("===== Kernel Source =====")
ksrc = self.get_kernel_source()
print(ksrc)
print("===== Host Source =====")
hsrc = self.get_host_source()
print(hsrc)
else:
raise ValueError(f"Unknown option for 'which': {which}")
except Exception as e:
logger.error(f"Failed to show source code: {e}")
def export_sources(self, kernel_path: str | None = None, host_path: str | None = None) -> None:
"""
Export generated source code to files.
Parameters
----------
kernel_path : Optional[str]
Destination file path to write the kernel source. If None, skips writing kernel code.
host_path : Optional[str]
Destination file path to write the host source. If None, skips writing host code.
Examples
--------
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> jit_kernel.export_sources(host_path="/tmp/host.cc")
>>> jit_kernel.export_sources(
... kernel_path="/tmp/kernel.cu",
... host_path="/tmp/host.cc",
... )
"""
if kernel_path is None and host_path is None:
raise ValueError("At least one of kernel_path or host_path must be provided.")
try:
if kernel_path is not None:
dir_path = os.path.dirname(kernel_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(kernel_path, 'w') as f:
f.write(self.get_kernel_source())
if host_path is not None:
dir_path = os.path.dirname(host_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(host_path, 'w') as f:
f.write(self.get_host_source())
except Exception as e:
logger.error(f"Failed to export sources: {e}")
# Backward compatibility alias (deprecated)
def print_source_code(self,
which: Literal["kernel", "host", "both"] = "kernel",
file: str | None = None) -> None:
"""
Deprecated: use show_source() or export_sources() instead.
Parameters
----------
which : Literal["kernel", "host", "both"], optional
Kept for backward compatibility with printing behavior.
file : Optional[str]
If provided, behaves like export_sources(kernel_path=file).
Examples
--------
>>> # New API (preferred)
>>> jit_kernel.show_source("both")
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
"""
logger.warning(
"print_source_code is deprecated; use show_source() or export_sources() instead.")
if file is not None:
# Historical behavior wrote only kernel source when file provided
self.export_sources(kernel_path=file)
else:
self.show_source(which=which)
def update_tuner_result(self, latency: float, config: dict[str, Any],
ref_latency: float) -> JITKernel:
"""
......@@ -483,3 +597,131 @@ class JITKernel:
# Export the compiled kernel function to a shared library file.
self.rt_module.export_library(kernel_file)
def _get_ptx(self, verbose: bool | None = None) -> str:
"""
Compile and return PTX for the current kernel (CUDA only).
Parameters
----------
verbose : Optional[bool]
Whether to enable verbose NVRTC logs. Defaults to self.verbose.
Returns
-------
str
The compiled PTX text.
"""
if not is_cuda_target(self.target):
raise ValueError("PTX is only available for CUDA targets.")
# Prefer NVCC for PTX generation via contrib helper
code = self.get_kernel_source()
if verbose is None:
verbose = self.verbose
# Ensure target is set so nvcc picks correct arch via Target.current()
with self.target:
return tl_nvcc.get_ptx_from_source(
code, compile_flags=self.compile_flags, verbose=verbose)
def show_ptx(self) -> None:
"""
Print compiled PTX for the kernel (CUDA only).
Examples
--------
>>> jit_kernel.show_ptx()
"""
try:
ptx = self._get_ptx()
print(ptx)
except Exception as e:
logger.error(f"Failed to show PTX: {e}")
def export_ptx(self, path: str) -> None:
"""
Export compiled PTX to a file (CUDA only).
Parameters
----------
path : str
Destination file path to write PTX.
Examples
--------
>>> jit_kernel.export_ptx("/tmp/kernel.ptx")
"""
if not path:
raise ValueError("path must be provided to export PTX")
try:
ptx = self._get_ptx()
dir_path = os.path.dirname(path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(path, "w") as f:
f.write(ptx)
logger.info(f"PTX saved to {os.path.abspath(path)}")
except Exception as e:
logger.error(f"Failed to export PTX: {e}")
def _get_sass(self, verbose: bool | None = None) -> str:
"""
Compile and return SASS for the current kernel (CUDA only).
Parameters
----------
verbose : Optional[bool]
Whether to enable verbose tool logs. Defaults to self.verbose.
Returns
-------
str
The disassembled SASS text.
"""
if not is_cuda_target(self.target):
raise ValueError("SASS is only available for CUDA targets.")
code = self.get_kernel_source()
if verbose is None:
verbose = self.verbose
with self.target:
return tl_nvcc.get_sass_from_source(
code, compile_flags=self.compile_flags, verbose=verbose)
def show_sass(self) -> None:
"""
Print disassembled SASS for the kernel (CUDA only).
Examples
--------
>>> jit_kernel.show_sass()
"""
try:
sass = self._get_sass()
print(sass)
except Exception as e:
logger.error(f"Failed to show SASS: {e}")
def export_sass(self, path: str) -> None:
"""
Export disassembled SASS to a file (CUDA only).
Parameters
----------
path : str
Destination file path to write SASS.
Examples
--------
>>> jit_kernel.export_sass("/tmp/kernel.sass")
"""
if not path:
raise ValueError("path must be provided to export SASS")
try:
sass = self._get_sass()
dir_path = os.path.dirname(path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(path, "w") as f:
f.write(sass)
logger.info(f"SASS saved to {os.path.abspath(path)}")
except Exception as e:
logger.error(f"Failed to export SASS: {e}")
......@@ -8,9 +8,9 @@ from __future__ import annotations
# upstream tir script is fully compatible
from tvm.script.parser.tir import *
from . import overrides as _overrides # noqa: F401
# from .tir import prim_func, macro, # noqa: F401
from .v2 import * # noqa: F401
from .tir.ir import * # noqa: F401
from tilelang.layout import Layout, Fragment # noqa: F401
from .proxy import (
......@@ -23,9 +23,7 @@ from .proxy import (
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
)
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401
from .kernel import (
......@@ -46,9 +44,12 @@ from .allocate import (
alloc_tmem, # noqa: F401
alloc_reducer, # noqa: F401
alloc_descriptor, # noqa: F401
alloc_wgmma_desc, # noqa: F401
alloc_tcgen05_smem_desc, # noqa: F401
alloc_tcgen05_instr_desc, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2 # noqa: F401
from .experimental.gemm_sp import gemm_sp # noqa: F401
from .fill import fill, clear # noqa: F401
from .reduce import (
......
......@@ -15,10 +15,13 @@ with the appropriate memory scope.
"""
from __future__ import annotations
from typing import overload, Literal
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
def alloc_shared(shape, dtype, scope="shared.dyn"):
......@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
return T.alloc_buffer(shape, dtype, scope=scope)
@overload
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer:
...
@overload
def alloc_var(dtype: str,
scope: str = 'local.var',
*,
init: PrimExpr | int | float | None = None) -> Buffer:
...
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
"""Allocate a single-element variable buffer.
......@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
init (PrimExpr, optional): The optional initializer value. When provided,
the generated code will initialize the variable with this value instead
of defaulting to zero.
Examples:
a = T.alloc_var('int32', 1) # var with init 1
a = T.alloc_var('int32', 'local.var') # var with local.var scope
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
......@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
buffer = T.alloc_buffer([1], dtype, scope=parsed_scope)
if parsed_init is not None:
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
if isinstance(parsed_init, (int, float, IntImm, FloatImm)):
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
else:
T.buffer_store(buffer, parsed_init, 0)
return buffer
......@@ -194,10 +218,40 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
return reducer
def alloc_descriptor(dtype="uint64", scope="local.descriptor"):
"""Allocate a descriptor buffer for wgmma and utcmma.
DescKind = Literal["wgmma", "tcgen05_smem", "tcgen05_instr"]
def alloc_descriptor(
kind: DescKind = "wgmma",
dtype: str = "uint64",
):
"""Allocate a descriptor buffer for WGMMA and TCGEN5.MMA.
Args:
kind: The descriptor kind, one of "wgmma", "tcgen05" ("utcmma" as alias).
Returns:
T.Buffer: A TVM buffer object allocated as a descriptor
"""
scope = "local.descriptor." + kind
# Buffer naming via `name` is not supported by this TVM builder signature;
# keep parameter for forward-compat, but do not pass it.
return T.alloc_buffer([1], dtype, scope=scope)
def alloc_wgmma_desc(dtype: str = "uint64"):
return alloc_descriptor("wgmma", dtype=dtype)
def alloc_tcgen05_smem_desc(dtype: str = "uint64"):
return alloc_descriptor("tcgen05_smem", dtype=dtype)
def alloc_tcgen05_instruction_desc(dtype: str = "uint32"):
return alloc_descriptor("tcgen05_instr", dtype=dtype)
# Alias: short name consistent with imports
def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
return alloc_tcgen05_instruction_desc(dtype)
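# Example (sketch): the wrappers only differ in the "local.descriptor.<kind>"
# scope and the default dtype they select:
#
#   desc = T.alloc_wgmma_desc()               # uint64, scope "local.descriptor.wgmma"
#   smem_desc = T.alloc_tcgen05_smem_desc()   # uint64, scope "local.descriptor.tcgen05_smem"
#   instr_desc = T.alloc_tcgen05_instr_desc() # uint32, scope "local.descriptor.tcgen05_instr"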
......@@ -1894,6 +1894,8 @@ ptx_mma = _dtype_forward(_tir_op.ptx_mma)
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss)
ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs)
ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss)
ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
......@@ -2145,6 +2147,7 @@ __all__ = [
"ptx_mma_sp",
"ptx_wgmma_ss",
"ptx_wgmma_rs",
"ptx_tcgen05_mma_ss",
"ptx_ldmatrix",
"ptx_cp_async",
"ptx_cp_async_bulk",
......
......@@ -6,8 +6,8 @@ from __future__ import annotations
import tilelang.language as T
from tvm import ir, tir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
from tilelang.utils.language import get_buffer_region_from_load, legalize_pairwise_extents
_MEMORY_ORDER_ID_MAP = {
"relaxed": 0,
......@@ -201,13 +201,14 @@ def atomic_add(dst: Buffer,
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
def _to_region(data, access_type, extent):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer):
zeros = [tir.IntImm("int32", 0) for _ in extent]
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad):
......@@ -218,8 +219,8 @@ def atomic_add(dst: Buffer,
else:
return buffer_load_to_tile_region(data, access_type, extent)
value = _to_region(value, "r")
dst = _to_region(dst, "w")
value = _to_region(value, "r", src_extent)
dst = _to_region(dst, "w", dst_extent)
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
......
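# Broadcast example (sketch; shapes are illustrative): with the pairwise
# legalization above,
#
#   T.atomic_add(dst[0:64, 0:64], src[0:1, 0:64])
#
# legalizes src_extent [1, 64] against dst_extent [64, 64] to ([64, 64], [64, 64]);
# mismatched non-unit extents such as [32, 64] vs [64, 64] raise an error.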
......@@ -5,9 +5,10 @@ from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import DataType, tir
from tvm.runtime import convert
from typing import Any
from tvm.tir import PrimExpr, Var, Call, BufferLoad, BufferRegion
_IS_HIP_AVAILABLE = check_hip_availability()
......@@ -429,6 +430,168 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
offset: int | PrimExpr = 0,
num_regs: int | PrimExpr | None = None,
dtype: str | None = None):
"""Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
WGMMA operations by issuing an empty inline assembly barrier on every register.
Args:
buffer_or_ptr: Buffer | BufferLoad | BufferRegion | PrimExpr
A buffer representing the accumulator fragment, a buffer load/region
that identifies a starting element within the fragment, or a pointer expression
(e.g., tvm_access_ptr/address_of/typed Var).
offset: int | PrimExpr
Element offset from the start of the accumulator fragment.
num_regs: int | PrimExpr | None
Number of 32-bit registers to fence. If None and a Buffer is provided, it will be
derived from the buffer shape and dtype.
dtype: str | None
Data type string of the accumulator elements. When passing a buffer or
buffer-derived expression, dtype is inferred. It is required only when
passing a raw pointer expression that cannot be inferred.
Returns:
tir.Call: A handle to the warpgroup fence operation.
"""
if isinstance(buffer_or_ptr, BufferLoad):
# Treat BufferLoad as a request to fence starting from the loaded element's address
buf = buffer_or_ptr.buffer
data_ptr = buf.data
inferred_dtype = buf.dtype
if dtype is not None and dtype != inferred_dtype:
raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.")
dtype = inferred_dtype
# Compute element offset from indices using strides if present, otherwise row-major
if len(buf.strides) == len(buf.shape) and len(buf.strides) > 0:
elem_off = 0
for idx, stride in zip(buffer_or_ptr.indices, buf.strides):
elem_off = elem_off + idx * stride
else:
elem_off = 0
stride_acc = 1
for idx, dim in zip(reversed(buffer_or_ptr.indices), reversed(buf.shape)):
elem_off = elem_off + idx * stride_acc
stride_acc = stride_acc * dim
# Combine with user-provided offset
offset = elem_off + convert(offset)
if num_regs is None:
raise ValueError("num_regs must be provided when passing a BufferLoad.")
return evaluate(
tir.call_intrin(
"handle",
tir.op.Op.get("tl.warpgroup_fence_operand"),
dtype,
data_ptr,
convert(offset),
convert(num_regs),
))
if isinstance(buffer_or_ptr, tir.Buffer):
data_ptr = buffer_or_ptr.data
inferred_dtype = buffer_or_ptr.dtype
if dtype is not None and dtype != inferred_dtype:
raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.")
dtype = inferred_dtype
if num_regs is None:
total_elems = 1
for dim in buffer_or_ptr.shape:
if isinstance(dim, tir.IntImm):
total_elems *= int(dim)
else:
raise ValueError(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic.")
bits_per_elem = DataType(dtype).bits
num_regs = (total_elems * bits_per_elem + 31) // 32
elif isinstance(buffer_or_ptr, BufferRegion):
buf = buffer_or_ptr.buffer
data_ptr = buf.data
inferred_dtype = buf.dtype
if dtype is not None and dtype != inferred_dtype:
raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.")
dtype = inferred_dtype
# Compute element offset from region min using strides if present, otherwise row-major
if len(buf.strides) == len(buf.shape) and len(buf.strides) > 0:
elem_off = 0
for r, stride in zip(buffer_or_ptr.region, buf.strides):
elem_off = elem_off + r.min * stride
else:
elem_off = 0
stride_acc = 1
for r, dim in zip(reversed(buffer_or_ptr.region), reversed(buf.shape)):
elem_off = elem_off + r.min * stride_acc
stride_acc = stride_acc * dim
# Combine with user-provided offset
offset = elem_off + convert(offset)
# Try derive num_regs from region extents if fully static; otherwise require user input
if num_regs is None:
total_elems = 1
static = True
for r in buffer_or_ptr.region:
if isinstance(r.extent, tir.IntImm):
total_elems *= int(r.extent)
else:
static = False
break
if static:
bits_per_elem = DataType(dtype).bits
num_regs = (total_elems * bits_per_elem + 31) // 32
else:
raise ValueError(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
return evaluate(
tir.call_intrin(
"handle",
tir.op.Op.get("tl.warpgroup_fence_operand"),
dtype,
data_ptr,
convert(offset),
convert(num_regs),
))
else:
data_ptr = buffer_or_ptr
# Try to infer dtype from common pointer expressions when not provided
if dtype is None:
inferred = None
# Case 1: Pointer from Buffer.access_ptr -> tir.builtin.tvm_access_ptr
if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
# args[0] is a type annotation call; its dtype carries the element dtype
inferred = str(data_ptr.args[0].dtype)
# Case 2: Pointer from tir.address_of(BufferLoad(...))
elif isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.address_of()):
# args[0] should be a BufferLoad; its dtype is the element dtype
inferred = str(data_ptr.args[0].dtype)
# Case 3: Typed pointer Var with PrimType element (typed TIR)
elif hasattr(data_ptr, "type_annotation") and data_ptr.type_annotation is not None:
try:
elem_ty = getattr(data_ptr.type_annotation, "element_type", None)
if elem_ty is not None and hasattr(elem_ty, "dtype"):
inferred = str(elem_ty.dtype)
except Exception:
inferred = None
if inferred is None:
raise ValueError(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
dtype = inferred
if num_regs is None:
raise ValueError("num_regs must be provided when passing a pointer expression.")
return evaluate(
tir.call_intrin(
"handle",
tir.op.Op.get("tl.warpgroup_fence_operand"),
dtype,
data_ptr,
convert(offset),
convert(num_regs),
))
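# Usage sketch (buffer name hypothetical): fence a 64x64 float32 accumulator
# fragment so its uses cannot be reordered across the surrounding WGMMA ops.
#
#   C_local = T.alloc_fragment((64, 64), "float32")
#   T.warpgroup_fence_operand(C_local)  # num_regs derived: (64*64*32 + 31)//32 = 4096
#   # With a raw pointer, dtype and num_regs must be given explicitly:
#   T.warpgroup_fence_operand(C_local.access_ptr("rw"), dtype="float32", num_regs=4096)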
def wait_wgmma(id: int):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
......@@ -537,38 +700,70 @@ def sync_grid():
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))
def initialize_wgmma_descriptor(
descriptor: tir.Buffer,
start_address: PrimExpr,
layout_type_: int = 0,
leading_byte_offset: int = 0,
stride_byte_offset: int = 0,
) -> PrimExpr:
"""Initialize a WGMMA/UTCMMA shared-memory descriptor.
Parameters:
descriptor (Buffer): The memory descriptor to initialize.
start_address (PrimExpr): The starting address of the memory region.
layout_type_ (int, optional): Layout type identifier. Defaults to 0.
leading_byte_offset (int, optional): Leading byte offset. Defaults to 0.
stride_byte_offset (int, optional): Stride byte offset. Defaults to 0.
Returns:
PrimExpr: A handle representing the initialized descriptor.
"""
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
descriptor, [0])
return evaluate(
tir.call_intrin(
"handle",
tir.op.Op.get("tl.initialize_wgmma_descriptor"),
descriptor,
start_address,
layout_type_,
int(leading_byte_offset),
int(stride_byte_offset),
))
def initialize_tcgen05_descriptor(
descriptor: tir.Buffer,
start_address: PrimExpr,
leading_byte_offset: int,
stride_byte_offset: int,
base_offset: int = 0,
leading_is_absolute: bool = False,
swizzle_mode: int = 0,
) -> PrimExpr:
"""Initialize a TCGEN05 shared-memory descriptor."""
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
descriptor, [0])
return evaluate(
tir.call_intrin("handle", tir.op.Op.get("tl.initialize_descriptor"), descriptor,
start_address, layout_type_, int(leading_byte_offset),
int(stride_byte_offset)))
tir.call_intrin(
"handle",
tir.op.Op.get("tl.initialize_tcgen05_descriptor"),
descriptor,
start_address,
int(leading_byte_offset),
int(stride_byte_offset),
int(base_offset),
tir.IntImm("int32", 1 if leading_is_absolute else 0),
int(swizzle_mode),
))
def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr:
......@@ -582,10 +777,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
Returns:
PrimExpr: A handle representing the modified descriptor.
"""
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
......@@ -606,3 +802,113 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
def tcgen05_mma_arrive(mbar_ptr):
"""Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
Parameters
----------
mbar_ptr : PrimExpr
Pointer to the mbarrier object in shared memory (e.g., Barrier*).
"""
return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
def ptx_mma_sm70(
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
):
"""TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).
This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape
with FP16 inputs and FP16/FP32 accumulation.
Parameters
----------
shape : str
The shape of mma fragment (e.g., "m16n16k4").
A_layout : str
The layout of multiplicand fragment A ("row" or "col").
B_layout : str
The layout of multiplicand fragment B ("row" or "col").
A_dtype : str
The data type of multiplicand fragment A (typically "fp16").
B_dtype : str
The data type of multiplicand fragment B (typically "fp16").
C_dtype : str
The data type of accumulator fragment C ("fp16" or "fp32").
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment B.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
Returns
-------
call : PrimExpr
The call expression.
Examples
--------
>>> T.ptx_mma_sm70(
... "float16",
... "m16n16k4",
... "row",
... "col",
... "fp16",
... "fp16",
... "fp16",
... A_local.data,
... 0,
... B_local.data,
... 0,
... C_local.data,
... 0,
... )
"""
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.ptx_mma_sm70"),
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
)
......@@ -3,9 +3,12 @@ from __future__ import annotations
from typing import Literal
from tilelang import language as T
from tilelang.utils.language import (
get_buffer_region_from_load,
legalize_pairwise_extents,
)
from tvm import ir, tir
from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
......@@ -55,15 +58,26 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
return tir.BufferStore(dst.buffer, src, dst.indices)
assert src_extent or dst_extent, "Can't deduce copy extents from args"
# Treat missing extent as length-matched ones to enable broadcasting logic.
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
# Align and broadcast extents from the right (tail) side independently
# for src and dst, so we can pass them unchanged into _to_region.
# Rules per-dim from the right:
# - equal -> keep both
# - one is 1 -> set that side to the other side's dim
# - otherwise -> error (see the pure-Python sketch after this hunk)
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
def _to_region(data, access_type, extent):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer):
# Restrict a raw buffer to the computed copy extent by creating
# a BufferLoad at origin and passing the extents explicitly.
zeros = [tir.IntImm("int32", 0) for _ in extent]
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad):
......@@ -74,8 +88,9 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
else:
return buffer_load_to_tile_region(data, access_type, extent)
src = _to_region(src, "r")
dst = _to_region(dst, "w")
# Use legalized extents for src and dst respectively.
src = _to_region(src, "r", src_extent)
dst = _to_region(dst, "w", dst_extent)
if coalesced_width is None:
coalesced_width = -1 # PrimExpr can not be None
......
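# A pure-Python sketch of the per-dimension rule used by
# legalize_pairwise_extents (the real helper lives in tilelang.utils.language
# and operates on PrimExprs; this illustration assumes plain ints):
#
#   def legalize_pairwise_extents_sketch(src, dst):
#       n = max(len(src), len(dst))
#       src = [1] * (n - len(src)) + list(src)  # align from the right
#       dst = [1] * (n - len(dst)) + list(dst)
#       for i, (s, d) in enumerate(zip(src, dst)):
#           if s == d:
#               continue            # equal -> keep both
#           elif s == 1:
#               src[i] = d          # broadcast the unit side
#           elif d == 1:
#               dst[i] = s
#           else:
#               raise ValueError(f"incompatible extents {s} vs {d}")
#       return src, dst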
......@@ -4,9 +4,14 @@ from __future__ import annotations
from tvm import tir
from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load
from tilelang.language.utils import (
buffer_to_tile_region,
buffer_region_to_tile_region,
buffer_load_to_tile_region,
)
def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value.
Args:
......@@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
Returns:
A TVM intrinsic call that performs the fill operation
"""
# Normalize Var with let value to its underlying object
if isinstance(buffer, tir.Var) and has_let_value(buffer):
buffer = get_let_value(buffer)
# Convert to a tl.region descriptor (PrimExpr) with write access
region_call = None
if isinstance(buffer, tir.Buffer):
buffer = buffer.access_ptr("w") # Get write pointer if input is a Buffer
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
region_call = buffer_to_tile_region(buffer, "w")
elif isinstance(buffer, tir.BufferRegion):
extents = [r.extent for r in buffer.region]
region_call = buffer_region_to_tile_region(buffer, "w", extents)
elif isinstance(buffer, tir.BufferLoad):
region = get_buffer_region_from_load(buffer)
if region is not None:
extents = [r.extent for r in region.region]
region_call = buffer_region_to_tile_region(region, "w", extents)
else:
# Fallback: treat element access as 1-extent per dim
region_call = buffer_load_to_tile_region(buffer, "w", [1] * len(buffer.indices))
else:
# As-is fallback (rare): pass through for downstream handling
region_call = buffer
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), region_call, value)
def clear(buffer: tir.Buffer | tir.Var):
......
......@@ -4,10 +4,19 @@ from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from tilelang.utils.language import (
to_buffer_region,
retrieve_shape,
retrieve_stride,
retrieve_ptr,
retrieve_offset,
prim_expr_equal,
)
from tilelang.env import env as _env
def _gemm_impl(
op_key: str,
A: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
......@@ -19,30 +28,9 @@ def gemm(
wg_wait: int = 0,
mbar: tir.Buffer | None = None,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
"""Shared GEMM implementation.
Returns a call_intrin handle for the given op key.
"""
def legalize_arguments(arg: tir.Buffer | tir.Var):
......@@ -63,52 +51,10 @@ def gemm(
C = legalize_arguments(C)
mbar = legalize_arguments(mbar) if mbar is not None else None
def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
region = object.region
shape = []
for r in region:
shape.append(r.extent)
return shape
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
shape = []
for r in region:
shape.append(r.extent)
return shape
else:
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
for s in reversed(object.shape):
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferRegion):
buffer, _ = object.buffer, object.region
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
else:
raise ValueError(
f"Unsupported retrieve_stride argument type: {type(object)} for buffer {object}")
# Normalize A/B/C to BufferRegion to pass into tl.gemm
A = to_buffer_region(A)
B = to_buffer_region(B)
C = to_buffer_region(C)
A_shape = retrieve_shape(A)
B_shape = retrieve_shape(B)
......@@ -132,68 +78,11 @@ def gemm(
M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1]
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert prim_expr_equal(K, K_B), f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
stride_a = A_stride[-2]
stride_b = B_stride[-2]
def retrieve_ptr(object: tir.Buffer | tir.BufferRegion | tir.BufferLoad, access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
buffer, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
# not offset the last two dimension
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
else:
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: tir.Buffer | tir.BufferRegion | tir.BufferLoad) -> list[tir.PrimExpr]:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
elif isinstance(object, tir.BufferRegion):
_, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
return indices
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
return indices
else:
raise ValueError(
f"Unsupported retrieve_offset argument type: {type(object)} for buffer {object}")
A_offset = retrieve_offset(A)
B_offset = retrieve_offset(B)
assert A_offset[-2] == 0, "The offset of the second-to-last dimension of A must be 0"
......@@ -201,18 +90,15 @@ def gemm(
offset_a = A_offset[-1]
offset_b = B_offset[-1]
Aptr = retrieve_ptr(A, "r")
Bptr = retrieve_ptr(B, "r")
Cptr = retrieve_ptr(C, "rw")
mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32")
C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A,
transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1])
C_coords = [r.min for r in C.region]
return tir.call_intrin("handle", tir.op.Op.get(op_key), A, B, C, transpose_A, transpose_B, M, N,
K, policy, clear_accum, stride_a, stride_b, offset_a, offset_b, k_pack,
wg_wait, mbarptr, C_coords[0], C_coords[1])
# Currently experimental; intended for fast compilation.
def gemm_v2(
# Public wrappers
def gemm_v1(
A: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
......@@ -222,205 +108,52 @@ def gemm_v2(
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
mbar: tir.Buffer | None = None,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A (Union[tir.Buffer, tir.Var]): First input matrix
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
return arg
A = legalize_arguments(A)
B = legalize_arguments(B)
C = legalize_arguments(C)
def retrieve_shape(object: tir.Buffer | tir.BufferRegion | tir.BufferLoad) -> list[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
region = object.region
shape = []
for r in region:
shape.append(r.extent)
return shape
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
shape = []
for r in region:
shape.append(r.extent)
return shape
else:
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: tir.Buffer | tir.BufferRegion | tir.BufferLoad) -> list[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
for s in reversed(object.shape):
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferRegion):
buffer, _ = object.buffer, object.region
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
else:
raise ValueError(
f"Unsupported retrieve_stride argument type: {type(object)} for buffer {object}")
A_shape = retrieve_shape(A)
B_shape = retrieve_shape(B)
C_shape = retrieve_shape(C)
A_stride = retrieve_stride(A)
B_stride = retrieve_stride(B)
assert len(C_shape) == 2, "currently only supports C as a 2D tensor"
assert len(A_shape) >= 2, "currently only supports A as a 2D or higher-order tensor"
assert len(B_shape) >= 2, "currently only supports B as a 2D or higher-order tensor"
if len(A_shape) > 2:
for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \
"currently only supports A as a higher-order tensor when the leading dimensions are 1 and the last two are the matrix dimensions"
if len(B_shape) > 2:
for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \
"currently only supports B as a higher-order tensor when the leading dimensions are 1 and the last two are the matrix dimensions"
M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1]
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
stride_a = A_stride[-2]
stride_b = B_stride[-2]
def retrieve_ptr(object: tir.Buffer | tir.BufferRegion | tir.BufferLoad, access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
buffer, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
# not offset the last two dimension
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
else:
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: tir.Buffer | tir.BufferRegion | tir.BufferLoad) -> list[tir.PrimExpr]:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
elif isinstance(object, tir.BufferRegion):
_, region = object.buffer, object.region
indices = []
for r in region:
indices.append(r.min)
return indices
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
return indices
else:
raise ValueError(
f"Unsupported retrieve_offset argument type: {type(object)} for buffer {object}")
"""GEMM v1: use op tl.gemm."""
return _gemm_impl(
"tl.gemm",
A,
B,
C,
transpose_A,
transpose_B,
policy,
clear_accum,
k_pack,
wg_wait,
mbar,
)
A_offset = retrieve_offset(A)
B_offset = retrieve_offset(B)
assert A_offset[-2] == 0, "The offset of the second-to-last dimension of A must be 0"
assert B_offset[-2] == 0, "The offset of the second-to-last dimension of B must be 0"
offset_a = A_offset[-1]
offset_b = B_offset[-1]
Aptr = retrieve_ptr(A, "r")
Bptr = retrieve_ptr(B, "r")
Cptr = retrieve_ptr(C, "rw")
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_py"),
Aptr,
Bptr,
Cptr,
# Currently experimental; intended for fast compilation.
def gemm_v2(
A: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
mbar: tir.Buffer | None = None,
):
"""GEMM v2: use op tl.gemm_py."""
return _gemm_impl(
"tl.gemm_py",
A,
B,
C,
transpose_A,
transpose_B,
M,
N,
K,
policy,
clear_accum,
stride_a,
stride_b,
offset_a,
offset_b,
k_pack,
wg_wait,
mbar,
)
# Default to v2; allow forcing v1 via environment variable
gemm = gemm_v1 if _env.use_gemm_v1() else gemm_v2
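# Hedged usage sketch (not part of this diff): how the public `gemm` wrapper
# is typically called inside a TileLang kernel. The tile sizes, buffer names,
# and the T.Tensor/T.Kernel spelling below are illustrative assumptions.
#
#   import tilelang.language as T
#
#   @T.prim_func
#   def matmul(A: T.Tensor((1024, 1024), "float16"),
#              B: T.Tensor((1024, 1024), "float16"),
#              C: T.Tensor((1024, 1024), "float32")):
#       with T.Kernel(8, 8, threads=128) as (bx, by):
#           A_s = T.alloc_shared((128, 128), "float16")
#           B_s = T.alloc_shared((128, 128), "float16")
#           C_l = T.alloc_fragment((128, 128), "float32")
#           T.clear(C_l)
#           for k in T.Pipelined(8, num_stages=3):
#               T.copy(A[by * 128, k * 128], A_s)
#               T.copy(B[k * 128, bx * 128], B_s)
#               T.gemm(A_s, B_s, C_l)  # resolves to gemm_v1 or gemm_v2 above
#           T.copy(C_l, C[by * 128, bx * 128])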
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any
from tvm import tir
from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep
from tilelang import _ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
"""Tools to construct nested parallel for loop.
This can be used to create element-wise tensor expression.
Parameters
----------
extents : PrimExpr
The extents of the iteration.
coalesced_width : Optional[int]
The coalesced width of the parallel loop.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
annotations: dict[str, Any] = {}
if coalesced_width is not None:
annotations.update({"coalesced_width": coalesced_width})
return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
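# Hedged usage sketch: T.Parallel yields loop variables for an element-wise
# region that the compiler maps onto threads; the buffer names are hypothetical.
#
#   for i, j in T.Parallel(128, 128):
#       C_s[i, j] = A_s[i, j] + B_s[i, j]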
def Persistent(
domain: list[tir.PrimExpr],
wave_size: tir.PrimExpr,
index: tir.PrimExpr,
group_size: tir.PrimExpr | None = 8,
):
"""Tools to construct persistent for loop.
Parameters
----------
domain : List[tir.PrimExpr]
The list of dominators.
wave_size : int
The wave size.
index : int
The tile index in one wave.
group_size : tir.PrimExpr
The group size.
"""
return _ffi_api.Persistent(domain, wave_size, index, group_size)
def Pipelined(
start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
......@@ -44,3 +92,20 @@ def Pipelined(
group = []
# type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)
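# Hedged usage sketch: a software-pipelined loop over K tiles; the tile
# count, stage depth, and buffer names are illustrative assumptions.
#
#   for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
#       T.copy(A[by * block_M, k * block_K], A_shared)
#       T.copy(B[k * block_K, bx * block_N], B_shared)
#       T.gemm(A_shared, B_shared, C_local)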
def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
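"""Serial loop helper that additionally supports a non-unit step.

A step of None or a literal 1 defers to the stock tb_tir.serial frame;
otherwise (start, stop) is normalized and a SerialForWithStep frame is
returned.
"""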
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
if step is None or step_is_one:
return tb_tir.serial(start, stop, annotations=annotations)
else:
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
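# Hedged usage sketch of the strided path above; the single-argument form is
# normalized to a (0, stop) range before stepping.
#
#   for i in serial(0, 16, 2):    # i = 0, 2, 4, ..., 14 via SerialForWithStep
#       ...
#   for i in serial(16, step=4):  # normalized to serial(0, 16, 4)
#       ...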
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any
from tvm import tir
from tilelang import _ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
"""Tools to construct nested parallel for loop.
This can be used to create element-wise tensor expression.
Parameters
----------
extents : PrimExpr
The extents of the iteration.
coalesced_width : Optional[int]
The coalesced width of the parallel loop.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
annotations: dict[str, Any] = {}
if coalesced_width is not None:
annotations.update({"coalesced_width": coalesced_width})
return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
"""The language interface for tl programs."""
from __future__ import annotations
from tvm import tir
from tilelang import _ffi_api
def Persistent(
domain: list[tir.PrimExpr],
wave_size: tir.PrimExpr,
index: tir.PrimExpr,
group_size: tir.PrimExpr | None = 8,
):
"""Tools to construct persistent for loop.
Parameters
----------
domain : List[tir.PrimExpr]
The list of dominators.
wave_size : int
The wave size.
index : int
The tile index in one wave.
group_size : tir.PrimExpr
The group size.
"""
return _ffi_api.Persistent(domain, wave_size, index, group_size)
......@@ -5,6 +5,7 @@ It includes functionality to print variables, print values in buffers, condition
from tvm import tir
from typing import Any
import tilelang.language as T
from tilelang.language.kernel import get_thread_bindings
from tilelang.language import copy, macro, serial, alloc_shared
from tilelang.language.utils import index_to_coordinates
......@@ -148,10 +149,10 @@ def device_assert(condition: tir.PrimExpr, msg: str = ""):
"""
if _IS_CUDA_AVAILABLE:
if msg == "":
tir.call_extern("void", "device_assert", condition)
T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition)
else:
warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
tir.call_extern("void", "device_assert_with_msg", condition, msg)
T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg)
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING
from typing import Any, SupportsIndex, TYPE_CHECKING
from collections.abc import Sequence
from typing_extensions import Self
from tvm import tir
......
......@@ -2,7 +2,10 @@
from __future__ import annotations
from tvm import tir
from tilelang.language import copy, macro, alloc_shared
from tilelang.language import copy, macro, alloc_shared, alloc_fragment
from tilelang.language.utils import buffer_to_tile_region
from tilelang.utils.language import is_shared, is_fragment
from tvm.script.ir_builder import IRBuilder
def _legalize_dim(buffer: tir.Buffer, dim: int):
......@@ -34,17 +37,70 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
raise ValueError(
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
buffer = buffer.access_ptr("r")
out = out.access_ptr("w")
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer,
out,
reduce_type,
dim,
clear,
)
@macro
def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
if is_shared(buffer) and is_shared(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
red_frag_out = alloc_fragment(out.shape, out.dtype)
# rename buffers
IRBuilder.name(buffer.name + "_frag", red_frag_in)
IRBuilder.name(out.name + "_frag", red_frag_out)
copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(red_frag_in, "r"),
buffer_to_tile_region(red_frag_out, "w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_shared(buffer) and is_fragment(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
IRBuilder.name(buffer.name + "_frag", red_frag_in)
copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(red_frag_in, "r"),
buffer_to_tile_region(out, "w"),
reduce_type,
dim,
clear,
)
elif is_fragment(buffer) and is_shared(out):
red_frag_out = alloc_fragment(out.shape, out.dtype)
IRBuilder.name(out.name + "_frag", red_frag_out)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(buffer, "r"),
buffer_to_tile_region(red_frag_out, "w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_fragment(buffer) and is_fragment(out):
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(buffer, "r"),
buffer_to_tile_region(out, "w"),
reduce_type,
dim,
clear,
)
else:
raise ValueError(f"Invalid buffer scopes: {buffer.scope()} and {out.scope()}")
return reduce_macro(buffer, out, reduce_type, dim, clear)
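# Hedged usage sketch of the shared -> fragment branch above: reducing a
# shared buffer stages it through a fragment before emitting one tl.reduce.
# Shapes and names are illustrative assumptions.
#
#   A_shared = T.alloc_shared((64, 128), "float32")
#   A_max = T.alloc_fragment((64,), "float32")
#   T.reduce_max(A_shared, A_max, dim=1)  # shared input, fragment output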
def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
......
......@@ -7,7 +7,6 @@ from tilelang.utils import deprecated
__all__ = ["dynamic", "symbolic"]
@deprecated("T.dynamic(...)", "tir.Var(...)", "v0.1.9")
def dynamic(name: str, dtype: str = "int32"):
"""
Create a TIR dynamic symbolic variable.
......@@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"):
return tir.Var(name, dtype)
@deprecated("T.symbolic(...)", "T.dynamic(...)")
@deprecated("T.symbolic(...)", "T.dynamic(...)", "v0.1.9")
def symbolic(name: str, dtype: str = "int32"):
"""Deprecated alias for `T.dynamic`."""
return tir.Var(name, dtype)
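# Hedged usage sketch: both helpers simply build a TIR variable, so new code
# can construct it directly, as the deprecation notices above suggest.
#
#   n = T.dynamic("n")         # deprecated alias, returns tir.Var("n", "int32")
#   n = tir.Var("n", "int32")  # preferred replacement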