Commit 3f294650 authored by Lei Wang, committed by LeiWang1999

[Refactor] Improve documentation and add detailed docstrings across multiple modules (#298)

* [Enhancement] Update AtomicAdd functions for BFLOAT16 in common.h

- Added conditional compilation for BFLOAT16 atomic operations to ensure compatibility with CUDA architectures greater than 7.5.
- Improved code clarity by organizing the AtomicAdd functions and adding relevant comments for better understanding.

* [Enhancement] Improve documentation and add detailed docstrings across multiple modules

- Updated the `__init__.py` file to enhance module documentation, providing clarity on auto-tuning functionalities.
- Added comprehensive docstrings to the `JITContext`, `AutotuneResult`, and `AutoTuner` classes, detailing their attributes and methods.
- Enhanced memory allocation utilities in `allocate.py` with detailed descriptions for each allocation function.
- Improved documentation for various intrinsic operations in `builtin.py`, `copy.py`, `customize.py`, `frame.py`, `gemm.py`, `memscope.py`, and `reduce.py`, ensuring clear explanations of parameters and return values.
- Refactored the `KernelCache` class to improve clarity and maintainability, including detailed comments and docstrings for methods.
- Overall, these changes aim to enhance code readability and provide better guidance for future developers and users of the Tile-AI framework.
parent 9ad9d9cd
"""The auto-tune module for tilelang programs.""" """The auto-tune module for tilelang programs.
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
import tilelang
from tilelang import tvm as tvm
...@@ -22,6 +26,19 @@ logging.basicConfig(
@dataclass(frozen=True)
class JITContext:
"""Context object for Just-In-Time compilation settings.
Attributes:
out_idx: List of output tensor indices.
supply_type: Type of tensor supply mechanism.
ref_prog: Reference program for correctness validation.
rtol: Relative tolerance for output validation.
atol: Absolute tolerance for output validation.
max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
skip_check: Whether to skip validation checks.
profiler: Profiler instance for performance measurement.
target: Target platform ('cuda' or 'hip').
"""
out_idx: List[int]
supply_type: tilelang.TensorSupplyType
ref_prog: Callable
...@@ -35,6 +52,16 @@ class JITContext:
@dataclass(frozen=True)
class AutotuneResult:
"""Results from auto-tuning process.
Attributes:
latency: Best achieved execution latency.
config: Configuration that produced the best result.
ref_latency: Reference implementation latency.
libcode: Generated library code.
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: float
config: dict
ref_latency: float
...@@ -44,6 +71,15 @@ class AutotuneResult:
class AutoTuner:
"""Auto-tuner for tilelang programs.
This class handles the auto-tuning process by testing different configurations
and finding the optimal parameters for program execution.
Args:
fn: The function to be auto-tuned.
configs: List of configurations to try during auto-tuning.
"""
def __init__(self, fn: Callable, configs):
self.fn = fn
...@@ -54,6 +90,15 @@ class AutoTuner:
@classmethod
def from_kernel(cls, kernel: Callable, configs):
"""Create an AutoTuner instance from a kernel function.
Args:
kernel: The kernel function to auto-tune.
configs: List of configurations to try.
Returns:
AutoTuner: A new AutoTuner instance.
"""
return cls(kernel, configs)

def set_compile_args(self,
...@@ -65,6 +110,21 @@ class AutoTuner:
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
target: Literal['auto', 'cuda', 'hip'] = 'auto'):
"""Set compilation arguments for the auto-tuner.
Args:
out_idx: List of output tensor indices.
supply_type: Type of tensor supply mechanism.
ref_prog: Reference program for validation.
rtol: Relative tolerance for validation.
atol: Absolute tolerance for validation.
max_mismatched_ratio: Maximum allowed mismatch ratio.
skip_check: Whether to skip validation.
target: Target platform.
Returns:
AutoTuner: Self for method chaining.
"""
def _compile(*config_arg):
kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
...@@ -85,6 +145,16 @@ class AutoTuner:
return self

def run(self, warmup: int = 25, rep: int = 100, timeout: int = 100):
"""Run the auto-tuning process.
Args:
warmup: Number of warmup iterations.
rep: Number of repetitions for timing.
timeout: Maximum time (in seconds) allowed per configuration.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
sig = inspect.signature(self.fn)
keys = list(sig.parameters.keys())
bound_args = sig.bind()
...@@ -192,12 +262,25 @@ class AutoTuner:
kernel=best_jit_context.profiler.func)

def __call__(self) -> Any:
"""Make the AutoTuner callable, running the auto-tuning process.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
return self.run()

def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
"""Decorator for auto-tuning tilelang programs.
Args:
configs: Configuration space to explore during auto-tuning.
warmup: Number of warmup iterations before timing.
rep: Number of repetitions for timing measurements.
timeout: Maximum time (in seconds) allowed for each configuration.
Returns:
Callable: Decorated function that performs auto-tuning.
""" """
def decorator(fn: Callable) -> AutoTuner: def decorator(fn: Callable) -> AutoTuner:
...@@ -217,6 +300,21 @@ def jit(out_idx: List[int], ...@@ -217,6 +300,21 @@ def jit(out_idx: List[int],
max_mismatched_ratio: float = 0.01, max_mismatched_ratio: float = 0.01,
skip_check: bool = False, skip_check: bool = False,
target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable: target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
"""Just-In-Time compilation decorator for tilelang programs.
Args:
out_idx: List of output tensor indices.
supply_type: Type of tensor supply mechanism.
ref_prog: Reference program for correctness validation.
rtol: Relative tolerance for output validation.
atol: Absolute tolerance for output validation.
max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
skip_check: Whether to skip validation checks.
target: Target platform ('auto', 'cuda', or 'hip').
Returns:
Callable: Decorated function that performs JIT compilation.
"""
def wrapper(fn: Callable):
......
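A usage sketch of the two decorators above (hypothetical program and search space; `autotune` is stacked outermost over `jit`, matching how the AutoTuner consumes the JIT-compiled function, and the import path is assumed):

import itertools
import tilelang
from tilelang.autotuner import autotune, jit  # assumed import path

# Hypothetical search space: candidate (block_M, block_N) tile sizes.
configs = list(itertools.product([64, 128], [32, 64]))

@autotune(configs=configs, warmup=25, rep=100, timeout=100)
@jit(out_idx=[1], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def program(block_M, block_N):
    ...  # would return a tilelang program parameterized by the tile sizes (elided)

result = program()  # __call__ -> run(); returns an AutotuneResult
print(result.latency, result.config)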
...@@ -30,31 +30,38 @@ class KernelCache:
kernel_lib.so: The compiled kernel library
params.pkl: The compiled kernel parameters
"""

_instance = None  # For implementing singleton pattern
_lock = threading.Lock()  # For thread safety

def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
"""Singleton pattern to ensure only one KernelCache instance"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:  # double-checked locking
instance = super().__new__(cls)
instance.cache_dir = cache_dir
os.makedirs(instance.cache_dir, exist_ok=True)
instance.logger = logging.getLogger(__name__)
instance.logger.setLevel(logging.ERROR)
cls._instance = instance
return cls._instance
def _generate_key(
self,
func: Callable,
out_idx: List[int],
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
args=None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
) -> str:
func_binary = cloudpickle.dumps(func.script())
key_data = {
"func": sha256(func_binary).hexdigest(),  # Use SHA256 to generate hash key
"out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
"args_repr": tuple(
repr(arg) for arg in args
),  # Use repr to serialize arguments, may need more robust serialization
...@@ -100,7 +107,13 @@ class KernelCache:
pass_configs=pass_configs,
)
key = self._generate_key(
func=func,
out_idx=out_idx,
execution_backend=execution_backend,
args=args,
target=target,
target_host=target_host)
with self._lock:  # TODO: use filelock
# Attempt to load from disk
kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
...@@ -122,8 +135,15 @@ class KernelCache:
self.logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:  # enter critical section again to check and update disk cache
disk_kernel = self._load_kernel_from_disk(
key,
target,
target_host,
out_idx,
execution_backend,
pass_configs,
func,
)
if disk_kernel is None:
self._save_kernel_to_disk(key, kernel, func)
return kernel
...@@ -180,14 +200,16 @@ class KernelCache:
except Exception as e:
self.logger.error(f"Error saving kernel parameters to disk: {e}")

def _load_kernel_from_disk(
self,
key: str,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None,
) -> JITKernel:
"""
Loads kernel from disk.
"""
......
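The `__new__` rewrite above is the classic double-checked locking idiom: the first, lock-free check keeps the common path cheap, and the second check under the lock closes the race between two first-time callers; the instance is fully initialized into a local before being published to `cls._instance`. A self-contained sketch of the same pattern, independent of tilelang:

import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Fast path: no lock once the instance exists.
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # ... initialize fully before publishing ...
                    cls._instance = instance
        return cls._instance

a, b = Singleton(), Singleton()
assert a is b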
"""The language interface for tl programs.""" """Memory allocation utilities for Tile-AI programs.
This module provides a set of functions for allocating different types of memory buffers
in Tile-AI programs. It wraps TVM's buffer allocation functionality with convenient
interfaces for different memory scopes.
Available allocation functions:
- alloc_shared: Allocates shared memory buffers for inter-thread communication
- alloc_local: Allocates local memory buffers for thread-private storage
- alloc_fragment: Allocates fragment memory buffers for specialized operations
- alloc_var: Allocates single-element variable buffers
Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""
from tvm.script import tir as T

def alloc_shared(shape, dtype, scope="shared.dyn"):
"""Allocate a shared memory buffer for inter-thread communication.
Args:
shape (tuple): The shape of the buffer to allocate
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "shared.dyn"
Returns:
T.Buffer: A TVM buffer object allocated in shared memory
"""
return T.alloc_buffer(shape, dtype, scope=scope)

def alloc_local(shape, dtype, scope="local"):
"""Allocate a local memory buffer for thread-private storage.
Args:
shape (tuple): The shape of the buffer to allocate
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "local"
Returns:
T.Buffer: A TVM buffer object allocated in local memory
"""
return T.alloc_buffer(shape, dtype, scope=scope)

def alloc_fragment(shape, dtype, scope="local.fragment"):
"""Allocate a fragment memory buffer for specialized operations.
Args:
shape (tuple): The shape of the buffer to allocate
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "local.fragment"
Returns:
T.Buffer: A TVM buffer object allocated in fragment memory
"""
return T.alloc_buffer(shape, dtype, scope=scope)

def alloc_var(dtype, scope="local.var"):
"""Allocate a single-element variable buffer.
Args:
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "local.var"
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
return T.alloc_buffer([1], dtype, scope=scope)
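A sketch of the allocators in context (tile sizes and buffer names illustrative, following the usual tilelang kernel structure):

import tilelang.language as T

def make_program(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):

    @T.prim_func
    def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)  # inter-thread staging tile
            A_frag = T.alloc_fragment((block_M, block_N), dtype)  # register-level fragment
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, A_frag)
            T.copy(A_frag, B[by * block_M, bx * block_N])

    return main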
...@@ -4,48 +4,144 @@ from tvm import tir

def CreateListofMBarrierOp(*args):
"""Create a list of memory barrier operations.
Args:
*args: Variable arguments passed to the memory barrier creation operation
Returns:
tir.Call: A handle to the created list of memory barriers
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateListofMBarrierOp"), *args)

def GetMBarrierOp(*args):
"""Retrieve a memory barrier operation.
Args:
*args: Variable arguments to specify which memory barrier to retrieve
Returns:
tir.Call: A handle to the requested memory barrier
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.GetMBarrierOp"), *args)

def CreateTMADescriptorOp(*args):
"""Create a Tensor Memory Access (TMA) descriptor.
Args:
*args: Variable arguments defining the TMA descriptor configuration
Returns:
tir.Call: A handle to the created TMA descriptor
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateTMADescriptorOp"), *args)

def TMALoadOp(*args):
"""Perform a Tensor Memory Access (TMA) load operation.
Args:
*args: Variable arguments specifying the TMA load parameters
Returns:
tir.Call: A handle to the TMA load operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMALoadOp"), *args)

def FenceProxyAsyncOp(*args):
"""Create a fence for asynchronous proxy operations.
Args:
*args: Variable arguments for fence configuration
Returns:
tir.Call: A handle to the fence operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args)

def TMAStoreArrive(*args):
"""Signal the arrival of a TMA store operation.
Args:
*args: Variable arguments for the store arrival operation
Returns:
tir.Call: A handle to the store arrive operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreArrive"), *args)

def TMAStoreWait(*args):
"""Wait for completion of TMA store operations.
Args:
*args: Variable arguments specifying which store operations to wait for
Returns:
tir.Call: A handle to the store wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreWait"), *args)

def SetMaxNReg(*args):
"""Set the maximum number of registers to use.
Args:
*args: Variable arguments specifying register allocation limits
Returns:
tir.Call: A handle to the register setting operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.SetMaxNReg"), *args)

def NoSetMaxNReg(*args):
"""Disable the maximum register limit setting.
Args:
*args: Variable arguments for the operation
Returns:
tir.Call: A handle to the register limit disable operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.NoSetMaxNReg"), *args)

def MBarrierWaitParity(*args):
"""Wait for memory barrier parity condition.
Args:
*args: Variable arguments specifying the parity wait condition
Returns:
tir.Call: A handle to the barrier wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierWaitParity"), *args)

def MBarrierExpectTX(*args):
"""Set expected transaction count for memory barrier.
Args:
*args: Variable arguments specifying the expected transaction count
Returns:
tir.Call: A handle to the barrier expectation operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierExpectTX"), *args)

def WaitWgmma(*args):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
Args:
*args: Variable arguments specifying which operations to wait for
Returns:
tir.Call: A handle to the WGMMA wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.WaitWgmma"), *args)
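Each wrapper above is a thin shim over `tir.call_intrin`: it builds a handle-typed call to the registered `tl.*` op and forwards `*args` verbatim, with validation happening on the C++ side. A sketch of what one such call looks like (the barrier index and parity bit are illustrative; importing tilelang is assumed to register the `tl.*` ops):

from tvm import tir
import tilelang  # assumed to register the tl.* TIR ops on import

# What MBarrierWaitParity(0, 1) builds: wait until mbarrier 0 reaches phase parity 1.
call = tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierWaitParity"), 0, 1)
print(call)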
...@@ -6,21 +6,59 @@ from tvm import ir, tir

def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
"""Create a memory region descriptor for tile operations.
Args:
buffer (tir.BufferLoad): The buffer to create a region for
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
*args (tir.PrimExpr): Extent expressions defining the region size
Returns:
tir.Call: A region descriptor for tile operations
"""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return tir.call_intrin("handle", tir.op.Op.get("tl.region"), buffer, access_type, *args)

def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
"""Convert a TVM buffer to a tile region descriptor.
Args:
buffer (tir.Buffer): The buffer to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor covering the entire buffer
"""
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)

def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
load (tir.BufferLoad): The buffer load operation
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): List of expressions defining the region size
Returns:
tir.Call: A region descriptor for the loaded area
"""
return region(load, access_type, *extents)

def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str):
"""Convert a buffer region to a tile region descriptor.
Args:
buffer_region (tir.BufferRegion): The buffer region to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor for the specified buffer region
"""
mins = [x.min for x in buffer_region.region]
extents = [x.extent for x in buffer_region.region]
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *extents)
...@@ -31,6 +69,19 @@ def copy(
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
):
"""Copy data between memory regions.
Args:
src (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Source memory region
dst (Union[tir.Buffer, tir.BufferLoad]): Destination memory region
coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None.
Raises:
TypeError: If copy extents cannot be deduced from arguments
Returns:
tir.Call: A handle to the copy operation
"""
if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
ir.assert_structural_equal(src.shape, dst.shape)
...@@ -82,6 +133,21 @@ def c2d_im2col(
dilation: int,
pad: int,
):
"""Perform im2col transformation for 2D convolution.
Args:
img (tir.Buffer): Input image buffer
col (tir.Buffer): Output column buffer
nhw_step (tir.PrimExpr): Step size for batch and spatial dimensions
c_step (tir.PrimExpr): Step size for channel dimension
kernel (int): Kernel size
stride (int): Stride of the convolution
dilation (int): Dilation rate
pad (int): Padding size
Returns:
tir.Call: A handle to the im2col operation
"""
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.c2d_im2col"),
......
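A short sketch of `copy` in the two-buffer form documented above, where the extents are deduced from matching shapes (meant to appear inside a kernel body; names illustrative, and the coalesced_width units are assumed to be elements):

import tilelang.language as T

def body(block_M=128, block_N=128, dtype="float16"):
    # Inside a T.Kernel scope (sketch):
    A_shared = T.alloc_shared((block_M, block_N), dtype)
    A_frag = T.alloc_fragment((block_M, block_N), dtype)
    T.copy(A_shared, A_frag)                     # Buffer -> Buffer, shapes must match
    T.copy(A_shared, A_frag, coalesced_width=4)  # with a coalescing hint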
...@@ -6,14 +6,42 @@ from typing import List, Union

def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added
Returns:
PrimExpr: Handle to the atomic addition operation
"""
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)

def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation with double-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (double-width)
Returns:
PrimExpr: Handle to the double-width atomic addition operation
"""
return T.call_extern("handle", "AtomicAddx2", T.address_of(dst), T.address_of(value))

def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
"""Perform a 4-element dot product with accumulation (DP4A).
Args:
A (Buffer): First input buffer
B (Buffer): Second input buffer
C (Buffer): Accumulation buffer
Returns:
PrimExpr: Handle to the DP4A operation
"""
return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C))
...@@ -37,8 +65,11 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
"""Reshapes the input buffer to the specified shape.

Args:
src (Buffer): Input buffer to be reshaped
shape (List[PrimExpr]): New shape for the buffer
Returns:
Buffer: A new buffer view with the specified shape
""" """
return T.Buffer(shape, src.dtype, src.data) return T.Buffer(shape, src.dtype, src.data)
...@@ -46,12 +77,15 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer: ...@@ -46,12 +77,15 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
def view(src: Buffer, def view(src: Buffer,
shape: Union[List[PrimExpr], None] = None, shape: Union[List[PrimExpr], None] = None,
dtype: Union[str, None] = None) -> Buffer: dtype: Union[str, None] = None) -> Buffer:
"""Views the input buffer to the specified shape. """Views the input buffer with optionally modified shape and dtype.
Args: Args:
src: Input buffer to be viewed src (Buffer): Input buffer to be viewed
shape: New shape for the buffer shape (Union[List[PrimExpr], None], optional): New shape for the buffer. Defaults to None.
dtype: New dtype for the buffer dtype (Union[str, None], optional): New dtype for the buffer. Defaults to None.
Returns:
Buffer: A new buffer view with the specified shape and dtype
""" """
if shape is None: if shape is None:
shape = src.shape shape = src.shape
......
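Both `reshape` and `view` alias the original storage (`src.data`) rather than copying; only the shape/dtype metadata of the returned `T.Buffer` differs. A sketch, assuming the helpers are re-exported through `tilelang.language`:

import tilelang.language as T

def body():
    # Inside a tile program (sketch): reinterpret a (4, 32) fp16 fragment.
    buf = T.alloc_fragment((4, 32), "float16")
    flat = T.reshape(buf, [4 * 32])       # same storage, flattened shape
    same = T.view(buf)                    # shape=None, dtype=None: identical view
    as_u16 = T.view(buf, dtype="uint16")  # same-width element reinterpretation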
...@@ -10,47 +10,74 @@ from typing import Optional

class FrameStack:
"""A stack-like container for managing TIR frame objects and their variable bindings.

This class implements a stack data structure using a deque and maintains a mapping
of variables to their values. It provides methods for stack operations and variable
value lookups.
"""

def __init__(self):
"""Initialize an empty frame stack and variable mapping."""
self._stack = deque()
self._var_value_map = {}

def push(self, item):
"""Push an item onto the stack and update variable mapping if applicable.
Args:
item: The frame object to push onto the stack
"""
self._stack.append(item)
# Store the var-value mapping if it's a LetFrame
if hasattr(item, 'var') and hasattr(item, 'value'):
self._var_value_map[item.var] = item.value

def pop(self):
"""Remove and return the top item from the stack.

Returns:
The top frame object from the stack

Raises:
IndexError: If the stack is empty
"""
if self._stack:
item = self._stack.pop()
# Clean up the var-value mapping if it's a LetFrame
if hasattr(item, 'var'):
self._var_value_map.pop(item.var, None)
return item
raise IndexError(f"{self.__class__.__name__} is empty")

def get_value(self, var):
"""Retrieve the value associated with a variable.
Args:
var: The variable to look up
Returns:
The value associated with the variable, or None if not found
"""
return self._var_value_map.get(var)

def has_value(self, var):
"""Check if a variable has an associated value.
Args:
var: The variable to check
Returns:
bool: True if the variable has an associated value, False otherwise
"""
return var in self._var_value_map

def top(self):
"""Return the top item of the stack without removing it.

Returns:
The top frame object from the stack

Raises:
IndexError: If the stack is empty
"""
if self._stack:
return self._stack[-1]
...@@ -74,8 +101,18 @@ _let_frame_stack = FrameStack()

@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
"""A TIR frame for let bindings that manages variable scope and value tracking.
This frame type extends TIRFrame to provide variable binding functionality and
maintains a global stack of active bindings.
"""
def __enter__(self) -> Var:
"""Enter the let frame scope and process buffer loads.
Returns:
Var: The variable bound in this frame
"""
super().__enter__()
if isinstance(self.value, BufferLoad):
indices = self.value.indices
...@@ -92,44 +129,73 @@ class LetFrame(TIRFrame):
return self.var

def __exit__(self, ptype, value, trace):
"""Exit the let frame scope and clean up the stack.
Args:
ptype: Exception type if an exception occurred
value: Exception value if an exception occurred
trace: Exception traceback if an exception occurred
"""
if _let_frame_stack.top() is self:
_let_frame_stack.pop()
super().__exit__(ptype, value, trace)

@classmethod
def Current(cls) -> "LetFrame":
"""Get the current (topmost) let frame.

Returns:
LetFrame: The current let frame
Raises:
IndexError: If there are no active let frames
""" """
return _let_frame_stack.top()

@staticmethod
def get_value(var: Var):
"""Get the value bound to a variable in any active frame.

Args:
var (Var): The variable to look up
Returns:
The value bound to the variable, or None if not found
""" """
return _let_frame_stack.get_value(var)

@staticmethod
def has_value(var: Var) -> bool:
"""Check if a variable has a binding in any active frame.
Args:
var (Var): The variable to check
Returns:
bool: True if the variable has a binding, False otherwise
""" """
return _let_frame_stack.has_value(var)

def has_let_value(var: Var) -> bool:
"""Check if a variable has a binding in the current let frame stack.
Args:
var (Var): The variable to check
Returns:
bool: True if the variable has a binding, False otherwise
""" """
return _let_frame_stack.has_value(var)

def get_let_value(var: Var) -> Optional[PrimExpr]:
"""Get the value bound to a variable in the current let frame stack.

Args:
var (Var): The variable to look up
Returns:
Optional[PrimExpr]: The bound value if found, None otherwise
""" """
return _let_frame_stack.get_value(var)
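A behavioral sketch of `FrameStack` using a stand-in frame object (real `LetFrame`s come from the script parser; the import path is assumed):

from collections import namedtuple
from tilelang.language.frame import FrameStack  # assumed module path

# Anything with .var and .value participates in the var-value mapping.
Frame = namedtuple("Frame", ["var", "value"])

stack = FrameStack()
stack.push(Frame(var="x", value=42))
assert stack.has_value("x") and stack.get_value("x") == 42
assert stack.top() is stack.pop()  # peek, then remove the same frame
assert not stack.has_value("x")    # pop also drops the binding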
...@@ -17,17 +17,41 @@ def gemm(
k_pack: int = 1,
wg_wait: int = 0,
):
"""Perform a General Matrix Multiplication (GEMM) operation.

This function computes C = A @ B, where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A (Union[tir.Buffer, tir.Var]): First input matrix
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
""" """
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
return arg

A = legalize_arguments(A)
B = legalize_arguments(B)
......
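A sketch of `T.gemm` in a complete tile program, following the canonical tilelang matmul pattern (shapes, tile sizes, and pipelining depth illustrative):

import tilelang.language as T

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype),
             C: T.Buffer((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)  # C_local += A_shared @ B_shared
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main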
...@@ -4,6 +4,11 @@ from tvm.ir import make_node

@register_func("tvm.info.mem.local.var")
def mem_info_local_var():
"""Get memory information for local variable memory.
Returns:
MemoryInfo: A node containing memory information for the 'local.var' scope
"""
return make_node(
"MemoryInfo",
unit_bits=8,
......
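The registered function is looked up like any other TVM global; a sketch of querying it (assuming tilelang has been imported so the registration above has run):

import tvm
import tilelang  # triggers the @register_func registration

mem_info = tvm.get_global_func("tvm.info.mem.local.var")()
print(mem_info.unit_bits)  # 8, per the make_node call above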
...@@ -4,6 +4,18 @@ from tvm import tir

def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
"""Perform a reduction operation on a buffer along a specified dimension.
Args:
buffer (tir.Buffer): Input buffer to reduce
out (tir.Buffer): Output buffer to store results
reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum')
dim (int): Dimension along which to perform reduction
clear (bool): Whether to initialize the output buffer before reduction
Returns:
tir.Call: Handle to the reduction operation
"""
buffer = buffer.access_ptr("r")
out = out.access_ptr("w")
return tir.call_intrin(
...@@ -38,12 +50,43 @@ def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True

def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
"""Perform reduce min on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
clear (bool, optional): If True, output buffer will be initialized to inf. Defaults to True.
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "min", dim, clear)

def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
"""Perform reduce sum on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "sum", dim, True)

def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
"""Perform reduce absolute sum on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "abssum", dim, True)
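A sketch of the reduction helpers inside a kernel body (a row-wise max over a fragment; names illustrative):

import tilelang.language as T

def body(block_M=64, block_N=64):
    # Inside a T.Kernel scope (sketch):
    acc = T.alloc_fragment((block_M, block_N), "float32")
    row_max = T.alloc_fragment((block_M,), "float32")
    # Reduce along dim=1 (columns); clear=True initializes row_max first.
    T.reduce_max(acc, row_max, dim=1, clear=True)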