Commit 3f294650 authored by Lei Wang, committed by LeiWang1999

[Refactor] Improve documentation and add detailed docstrings across multiple modules (#298)

* [Enhancement] Update AtomicAdd functions for BFLOAT16 in common.h

- Added conditional compilation for BFLOAT16 atomic operations to ensure compatibility with CUDA architectures greater than 7.5.
- Improved code clarity by organizing the AtomicAdd functions and adding relevant comments for better understanding.

* [Enhancement] Improve documentation and add detailed docstrings across multiple modules

- Updated the `__init__.py` file to enhance module documentation, providing clarity on auto-tuning functionalities.
- Added comprehensive docstrings to the `JITContext`, `AutotuneResult`, and `AutoTuner` classes, detailing their attributes and methods.
- Enhanced memory allocation utilities in `allocate.py` with detailed descriptions for each allocation function.
- Improved documentation for various intrinsic operations in `builtin.py`, `copy.py`, `customize.py`, `frame.py`, `gemm.py`, `memscope.py`, and `reduce.py`, ensuring clear explanations of parameters and return values.
- Refactored the `KernelCache` class to improve clarity and maintainability, including detailed comments and docstrings for methods.
- Overall, these changes aim to enhance code readability and provide better guidance for future developers and users of the Tile-AI framework.
parent 9ad9d9cd
"""The auto-tune module for tilelang programs."""
"""The auto-tune module for tilelang programs.
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
import tilelang
from tilelang import tvm as tvm
@@ -22,6 +26,19 @@ logging.basicConfig(
@dataclass(frozen=True)
class JITContext:
"""Context object for Just-In-Time compilation settings.
Attributes:
out_idx: List of output tensor indices.
supply_type: Type of tensor supply mechanism.
ref_prog: Reference program for correctness validation.
rtol: Relative tolerance for output validation.
atol: Absolute tolerance for output validation.
max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
skip_check: Whether to skip validation checks.
profiler: Profiler instance for performance measurement.
target: Target platform ('cuda' or 'hip').
"""
out_idx: List[int]
supply_type: tilelang.TensorSupplyType
ref_prog: Callable
@@ -35,6 +52,16 @@ class JITContext:
@dataclass(frozen=True)
class AutotuneResult:
"""Results from auto-tuning process.
Attributes:
latency: Best achieved execution latency.
config: Configuration that produced the best result.
ref_latency: Reference implementation latency.
libcode: Generated library code.
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: float
config: dict
ref_latency: float
@@ -44,6 +71,15 @@
class AutoTuner:
"""Auto-tuner for tilelang programs.
This class handles the auto-tuning process by testing different configurations
and finding the optimal parameters for program execution.
Args:
fn: The function to be auto-tuned.
configs: List of configurations to try during auto-tuning.
"""
def __init__(self, fn: Callable, configs):
self.fn = fn
@@ -54,6 +90,15 @@
@classmethod
def from_kernel(cls, kernel: Callable, configs):
"""Create an AutoTuner instance from a kernel function.
Args:
kernel: The kernel function to auto-tune.
configs: List of configurations to try.
Returns:
AutoTuner: A new AutoTuner instance.
"""
return cls(kernel, configs)
def set_compile_args(self,
@@ -65,6 +110,21 @@
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
target: Literal['auto', 'cuda', 'hip'] = 'auto'):
"""Set compilation arguments for the auto-tuner.
Args:
out_idx: List of output tensor indices.
supply_type: Type of tensor supply mechanism.
ref_prog: Reference program for validation.
rtol: Relative tolerance for validation.
atol: Absolute tolerance for validation.
max_mismatched_ratio: Maximum allowed mismatch ratio.
skip_check: Whether to skip validation.
target: Target platform.
Returns:
AutoTuner: Self for method chaining.
"""
def _compile(*config_arg):
kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
@@ -85,6 +145,16 @@
return self
def run(self, warmup: int = 25, rep: int = 100, timeout: int = 100):
"""Run the auto-tuning process.
Args:
warmup: Number of warmup iterations.
rep: Number of repetitions for timing.
timeout: Maximum time (in seconds) allowed per configuration.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
sig = inspect.signature(self.fn)
keys = list(sig.parameters.keys())
bound_args = sig.bind()
@@ -192,12 +262,25 @@
kernel=best_jit_context.profiler.func)
def __call__(self) -> Any:
"""Make the AutoTuner callable, running the auto-tuning process.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
return self.run()
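# --- Editor's example (sketch, not part of this commit) ---
# The programmatic, non-decorator path through AutoTuner, chaining the
# methods documented above. `make_kernel` and the config tuples are
# hypothetical placeholders.
tuner = AutoTuner.from_kernel(make_kernel, configs=[(64, 64), (128, 64)])
result = tuner.set_compile_args(
    out_idx=[-1],
    ref_prog=None,
    skip_check=True,
    target="auto",
).run(warmup=25, rep=100)
print(result.latency, result.config)  # fields of AutotuneResult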
def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
"""
Decorator for tilelang program
"""Decorator for auto-tuning tilelang programs.
Args:
configs: Configuration space to explore during auto-tuning.
warmup: Number of warmup iterations before timing.
rep: Number of repetitions for timing measurements.
timeout: Maximum time (in seconds) allowed for each configuration.
Returns:
Callable: Decorated function that performs auto-tuning.
"""
def decorator(fn: Callable) -> AutoTuner:
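# --- Editor's example (sketch, not part of this commit) ---
# Typical decorator-stacked usage: @autotune drives the search while @jit
# compiles each candidate. `make_matmul` is a hypothetical factory returning
# a tilelang program parameterized by tile sizes, and the shape of each
# config entry is an assumption.
import itertools

configs = [{"block_M": m, "block_N": n}
           for m, n in itertools.product([64, 128], [64, 128])]

@autotune(configs=configs, warmup=25, rep=100, timeout=100)
@jit(out_idx=[-1], skip_check=True, target="auto")
def tuned_matmul(block_M=64, block_N=64):
    return make_matmul(block_M, block_N)  # hypothetical program factory

result = tuned_matmul()  # runs the search; returns an AutotuneResult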
@@ -217,6 +300,21 @@ def jit(out_idx: List[int],
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
"""Just-In-Time compilation decorator for tilelang programs.
Args:
out_idx: List of output tensor indices.
supply_type: Type of tensor supply mechanism.
ref_prog: Reference program for correctness validation.
rtol: Relative tolerance for output validation.
atol: Absolute tolerance for output validation.
max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
skip_check: Whether to skip validation checks.
target: Target platform ('auto', 'cuda', or 'hip').
Returns:
Callable: Decorated function that performs JIT compilation.
"""
def wrapper(fn: Callable):
......
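# --- Editor's example (sketch, not part of this commit) ---
# Standalone @jit usage with a reference program for numerical validation;
# `ref_matmul` and `make_matmul` are hypothetical.
@jit(out_idx=[-1],
     ref_prog=ref_matmul,           # outputs are checked against this reference
     rtol=1e-2, atol=1e-2,
     max_mismatched_ratio=0.01)
def compiled_matmul():
    return make_matmul(block_M=128, block_N=128)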
@@ -30,31 +30,38 @@ class KernelCache:
kernel_lib.so: The compiled kernel library
params.pkl: The compiled kernel parameters
"""
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
"""Singleton pattern to ensure only one KernelCache instance"""
with cls._lock:
if cls._instance is None:  # double-checked locking
instance = super().__new__(cls)
instance.cache_dir = cache_dir
os.makedirs(instance.cache_dir, exist_ok=True)
instance.logger = logging.getLogger(__name__)
instance.logger.setLevel(logging.ERROR)
cls._instance = instance
return cls._instance
def _generate_key(
        self,
        func: Callable,
        out_idx: List[int],
        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
        args=None,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
) -> str:
    """Generates a unique cache key."""
func_binary = cloudpickle.dumps(func.script())
key_data = {
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
"out_idx": tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx],
"out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
"args_repr": tuple(
repr(arg) for arg in args
), # Use repr to serialize arguments, may need more robust serialization
@@ -100,7 +107,13 @@ class KernelCache:
pass_configs=pass_configs,
)
key = self._generate_key(
func=func,
out_idx=out_idx,
execution_backend=execution_backend,
args=args,
target=target,
target_host=target_host)
with self._lock: # TODO: use filelock
# Attempt to load from disk
kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
@@ -122,8 +135,15 @@
self.logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock: # enter critical section again to check and update disk cache
disk_kernel = self._load_kernel_from_disk(
key,
target,
target_host,
out_idx,
execution_backend,
pass_configs,
func,
)
if disk_kernel is None:
self._save_kernel_to_disk(key, kernel, func)
return kernel
@@ -180,14 +200,16 @@
except Exception as e:
self.logger.error(f"Error saving kernel parameters to disk: {e}")
def _load_kernel_from_disk(
self,
key: str,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None) -> JITKernel:
func: Callable = None,
) -> JITKernel:
"""
Loads kernel from disk.
"""
......
"""The language interface for tl programs."""
"""Memory allocation utilities for Tile-AI programs.
This module provides a set of functions for allocating different types of memory buffers
in Tile-AI programs. It wraps TVM's buffer allocation functionality with convenient
interfaces for different memory scopes.
Available allocation functions:
- alloc_shared: Allocates shared memory buffers for inter-thread communication
- alloc_local: Allocates local memory buffers for thread-private storage
- alloc_fragment: Allocates register fragment buffers for matrix operands and accumulators
- alloc_var: Allocates single-element variable buffers
Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""
from tvm.script import tir as T
def alloc_shared(shape, dtype, scope="shared.dyn"):
"""Allocate a shared memory buffer for inter-thread communication.
Args:
shape (tuple): The shape of the buffer to allocate
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "shared.dyn"
Returns:
T.Buffer: A TVM buffer object allocated in shared memory
"""
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_local(shape, dtype, scope="local"):
"""Allocate a local memory buffer for thread-private storage.
Args:
shape (tuple): The shape of the buffer to allocate
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "local"
Returns:
T.Buffer: A TVM buffer object allocated in local memory
"""
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_fragment(shape, dtype, scope="local.fragment"):
"""Allocate a fragment memory buffer for specialized operations.
Args:
shape (tuple): The shape of the buffer to allocate
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "local.fragment"
Returns:
T.Buffer: A TVM buffer object allocated in fragment memory
"""
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_var(dtype, scope="local.var"):
"""Allocate a single-element variable buffer.
Args:
dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
scope (str, optional): The memory scope. Defaults to "local.var"
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
return T.alloc_buffer([1], dtype, scope=scope)
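# --- Editor's example (sketch, not part of this commit) ---
# How the four helpers are typically combined inside a tilelang kernel body.
# Shapes and names are hypothetical, and the calls only take effect inside a
# TVM script parsing context.
A_shared = alloc_shared((128, 32), "float16")  # block-level staging tile
acc = alloc_fragment((128, 128), "float32")    # register-level accumulator
scratch = alloc_local((8,), "float32")         # thread-private scratch
counter = alloc_var("int32")                   # single-element variable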
@@ -4,48 +4,144 @@ from tvm import tir
def CreateListofMBarrierOp(*args):
"""Create a list of memory barrier operations.
Args:
*args: Variable arguments passed to the memory barrier creation operation
Returns:
tir.Call: A handle to the created list of memory barriers
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateListofMBarrierOp"), *args)
def GetMBarrierOp(*args):
"""Retrieve a memory barrier operation.
Args:
*args: Variable arguments to specify which memory barrier to retrieve
Returns:
tir.Call: A handle to the requested memory barrier
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.GetMBarrierOp"), *args)
def CreateTMADescriptorOp(*args):
"""Create a Tensor Memory Access (TMA) descriptor.
Args:
*args: Variable arguments defining the TMA descriptor configuration
Returns:
tir.Call: A handle to the created TMA descriptor
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateTMADescriptorOp"), *args)
def TMALoadOp(*args):
"""Perform a Tensor Memory Access (TMA) load operation.
Args:
*args: Variable arguments specifying the TMA load parameters
Returns:
tir.Call: A handle to the TMA load operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMALoadOp"), *args)
def FenceProxyAsyncOp(*args):
"""Create a fence for asynchronous proxy operations.
Args:
*args: Variable arguments for fence configuration
Returns:
tir.Call: A handle to the fence operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args)
def TMAStoreArrive(*args):
"""Signal the arrival of a TMA store operation.
Args:
*args: Variable arguments for the store arrival operation
Returns:
tir.Call: A handle to the store arrive operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreArrive"), *args)
def TMAStoreWait(*args):
"""Wait for completion of TMA store operations.
Args:
*args: Variable arguments specifying which store operations to wait for
Returns:
tir.Call: A handle to the store wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreWait"), *args)
def SetMaxNReg(*args):
"""Set the maximum number of registers to use.
Args:
*args: Variable arguments specifying register allocation limits
Returns:
tir.Call: A handle to the register setting operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.SetMaxNReg"), *args)
def NoSetMaxNReg(*args):
"""Disable the maximum register limit setting.
Args:
*args: Variable arguments for the operation
Returns:
tir.Call: A handle to the register limit disable operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.NoSetMaxNReg"), *args)
def MBarrierWaitParity(*args):
"""Wait for memory barrier parity condition.
Args:
*args: Variable arguments specifying the parity wait condition
Returns:
tir.Call: A handle to the barrier wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierWaitParity"), *args)
def MBarrierExpectTX(*args):
"""Set expected transaction count for memory barrier.
Args:
*args: Variable arguments specifying the expected transaction count
Returns:
tir.Call: A handle to the barrier expectation operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierExpectTX"), *args)
def WaitWgmma(*args):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
Args:
*args: Variable arguments specifying which operations to wait for
Returns:
tir.Call: A handle to the WGMMA wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.WaitWgmma"), *args)
@@ -6,21 +6,59 @@ from tvm import ir, tir
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
"""Create a memory region descriptor for tile operations.
Args:
buffer (tir.BufferLoad): The buffer to create a region for
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
*args (tir.PrimExpr): Extent expressions defining the region size
Returns:
tir.Call: A region descriptor for tile operations
"""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return tir.call_intrin("handle", tir.op.Op.get("tl.region"), buffer, access_type, *args)
def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
"""Convert a TVM buffer to a tile region descriptor.
Args:
buffer (tir.Buffer): The buffer to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor covering the entire buffer
"""
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
load (tir.BufferLoad): The buffer load operation
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): List of expressions defining the region size
Returns:
tir.Call: A region descriptor for the loaded area
"""
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str):
"""Convert a buffer region to a tile region descriptor.
Args:
buffer_region (tir.BufferRegion): The buffer region to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor for the specified buffer region
"""
mins = [x.min for x in buffer_region.region]
extents = [x.extent for x in buffer_region.region]
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *extents)
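# --- Editor's example (sketch, not part of this commit) ---
# For a buffer B of shape (M, N), the converters above produce equivalent
# full-buffer read regions; access_type "r" maps to the bitmask 1, "w" to 2,
# and "rw" to 3. `B` is hypothetical.
r1 = buffer_to_tile_region(B, "r")
r2 = buffer_load_to_tile_region(T.BufferLoad(B, [0, 0]), "r", list(B.shape))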
@@ -31,6 +69,19 @@ def copy(
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
):
"""Copy data between memory regions.
Args:
src (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Source memory region
dst (Union[tir.Buffer, tir.BufferLoad]): Destination memory region
coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None.
Raises:
TypeError: If copy extents cannot be deduced from arguments
Returns:
tir.Call: A handle to the copy operation
"""
if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
ir.assert_structural_equal(src.shape, dst.shape)
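# --- Editor's example (sketch, not part of this commit) ---
# Staging a tile from global into shared memory inside a kernel body; buffer
# names are hypothetical, and the extents are assumed to be deducible from
# the destination shape.
copy(A[by * 128, ko * 32], A_shared)       # global -> shared tile
copy(A_shared, A_frag, coalesced_width=8)  # optional coalescing hint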
@@ -82,6 +133,21 @@ def c2d_im2col(
dilation: int,
pad: int,
):
"""Perform im2col transformation for 2D convolution.
Args:
img (tir.Buffer): Input image buffer
col (tir.Buffer): Output column buffer
nhw_step (tir.PrimExpr): Step size for batch and spatial dimensions
c_step (tir.PrimExpr): Step size for channel dimension
kernel (int): Kernel size
stride (int): Stride of the convolution
dilation (int): Dilation rate
pad (int): Padding size
Returns:
tir.Call: A handle to the im2col operation
"""
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.c2d_im2col"),
......
@@ -6,14 +6,42 @@ from typing import List, Union
def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added
Returns:
PrimExpr: Handle to the atomic addition operation
"""
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation with double-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (double-width)
Returns:
PrimExpr: Handle to the double-width atomic addition operation
"""
return T.call_extern("handle", "AtomicAddx2", T.address_of(dst), T.address_of(value))
def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
"""Perform a 4-element dot product with accumulation (DP4A).
Args:
A (Buffer): First input buffer
B (Buffer): Second input buffer
C (Buffer): Accumulation buffer
Returns:
PrimExpr: Handle to the DP4A operation
"""
return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C))
@@ -37,8 +65,11 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
"""Reshapes the input buffer to the specified shape.
Args:
src (Buffer): Input buffer to be reshaped
shape (List[PrimExpr]): New shape for the buffer
Returns:
Buffer: A new buffer view with the specified shape
"""
return T.Buffer(shape, src.dtype, src.data)
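# --- Editor's example (sketch, not part of this commit) ---
# reshape returns a view over the same storage (src.data is shared), so no
# data is moved. `x` is a hypothetical (M, N) buffer.
flat = reshape(x, [x.shape[0] * x.shape[1]])  # (M, N) -> (M*N,) view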
@@ -46,12 +77,15 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
def view(src: Buffer,
shape: Union[List[PrimExpr], None] = None,
dtype: Union[str, None] = None) -> Buffer:
"""Views the input buffer to the specified shape.
"""Views the input buffer with optionally modified shape and dtype.
Args:
src (Buffer): Input buffer to be viewed
shape (Union[List[PrimExpr], None], optional): New shape for the buffer. Defaults to None.
dtype (Union[str, None], optional): New dtype for the buffer. Defaults to None.
Returns:
Buffer: A new buffer view with the specified shape and dtype
"""
if shape is None:
shape = src.shape
......
@@ -10,47 +10,74 @@ from typing import Optional
class FrameStack:
"""
A stack-like wrapper around a deque that provides push, pop, and top methods,
along with a var-value mapping functionality.
"""A stack-like container for managing TIR frame objects and their variable bindings.
This class implements a stack data structure using a deque and maintains a mapping
of variables to their values. It provides methods for stack operations and variable
value lookups.
"""
def __init__(self):
"""Initialize an empty frame stack and variable mapping."""
self._stack = deque()
self._var_value_map = {}
def push(self, item):
"""Pushes an item onto the top of the stack."""
"""Push an item onto the stack and update variable mapping if applicable.
Args:
item: The frame object to push onto the stack
"""
self._stack.append(item)
# Store the var-value mapping if it's a LetFrame
if hasattr(item, 'var') and hasattr(item, 'value'):
self._var_value_map[item.var] = item.value
def pop(self):
"""
Pops and returns the top of the stack, or returns None
if the stack is empty.
"""Remove and return the top item from the stack.
Returns:
The top frame object from the stack
Raises:
IndexError: If the stack is empty
"""
if self._stack:
item = self._stack.pop()
# Clean up the var-value mapping if it's a LetFrame
if hasattr(item, 'var'):
self._var_value_map.pop(item.var, None)
return item
raise IndexError(f"{self.__class__.__name__} is empty")
def get_value(self, var):
"""Get the value associated with a variable."""
"""Retrieve the value associated with a variable.
Args:
var: The variable to look up
Returns:
The value associated with the variable, or None if not found
"""
return self._var_value_map.get(var)
def has_value(self, var):
"""Check if a variable has an associated value."""
"""Check if a variable has an associated value.
Args:
var: The variable to check
Returns:
bool: True if the variable has an associated value, False otherwise
"""
return var in self._var_value_map
def top(self):
"""
Returns the item on the top of the stack without removing it,
or None if the stack is empty.
"""Return the top item of the stack without removing it.
Returns:
The top frame object from the stack
Raises:
IndexError: If the stack is empty
"""
if self._stack:
return self._stack[-1]
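# --- Editor's example (sketch, not part of this commit) ---
# FrameStack behavior with a minimal stand-in frame; real LetFrame objects
# are pushed by the TVM script IR builder.
from types import SimpleNamespace

stack = FrameStack()
frame = SimpleNamespace(var="x", value=42)  # any object with .var/.value
stack.push(frame)
assert stack.has_value("x") and stack.get_value("x") == 42
assert stack.top() is frame
stack.pop()  # also drops the "x" binding
assert not stack.has_value("x")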
@@ -74,8 +101,18 @@ _let_frame_stack = FrameStack()
@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
"""A TIR frame for let bindings that manages variable scope and value tracking.
This frame type extends TIRFrame to provide variable binding functionality and
maintains a global stack of active bindings.
"""
def __enter__(self) -> Var:
"""Enter the let frame scope and process buffer loads.
Returns:
Var: The variable bound in this frame
"""
super().__enter__()
if isinstance(self.value, BufferLoad):
indices = self.value.indices
@@ -92,44 +129,73 @@ class LetFrame(TIRFrame):
return self.var
def __exit__(self, ptype, value, trace):
"""Exit the let frame scope and clean up the stack.
Args:
ptype: Exception type if an exception occurred
value: Exception value if an exception occurred
trace: Exception traceback if an exception occurred
"""
if _let_frame_stack.top() is self:
_let_frame_stack.pop()
super().__exit__(ptype, value, trace)
@classmethod
def Current(cls) -> "LetFrame":
"""
Returns the topmost (current) LetFrame from the stack if it exists,
or raises IndexError if the stack is empty.
"""Get the current (topmost) let frame.
Returns:
LetFrame: The current let frame
Raises:
IndexError: If there are no active let frames
"""
return _let_frame_stack.top()
@staticmethod
def get_value(var: Var):
"""
Get the value associated with a variable.
Returns None if the variable is not found.
"""Get the value bound to a variable in any active frame.
Args:
var (Var): The variable to look up
Returns:
The value bound to the variable, or None if not found
"""
return _let_frame_stack.get_value(var)
@staticmethod
def has_value(var: Var) -> bool:
"""
Check if a variable has an associated value.
"""Check if a variable has a binding in any active frame.
Args:
var (Var): The variable to check
Returns:
bool: True if the variable has a binding, False otherwise
"""
return _let_frame_stack.has_value(var)
def has_let_value(var: Var) -> bool:
"""
Check if a variable has an associated value in the let frame stack.
"""Check if a variable has a binding in the current let frame stack.
Args:
var (Var): The variable to check
Returns:
bool: True if the variable has a binding, False otherwise
"""
return _let_frame_stack.has_value(var)
def get_let_value(var: Var) -> Optional[PrimExpr]:
"""
Get the value associated with a variable from the let frame stack.
Returns None if the variable is not found.
"""Get the value bound to a variable in the current let frame stack.
Args:
var (Var): The variable to look up
Returns:
Optional[PrimExpr]: The bound value if found, None otherwise
"""
return _let_frame_stack.get_value(var)
@@ -17,16 +17,40 @@ def gemm(
k_pack: int = 1,
wg_wait: int = 0,
):
"""
k_pack: int
The number of k dimension that is packed into a single warp.
please ref to mfma macro generator for the detail information.
"""Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A (Union[tir.Buffer, tir.Var]): First input matrix
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
else:
return arg
A = legalize_arguments(A)
......
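# --- Editor's example (sketch, not part of this commit) ---
# One K-tile accumulation step, C_frag += A_shared @ B_shared^T. Buffer names
# and shapes are hypothetical: A_shared is (128, 32), B_shared is (128, 32)
# stored transposed, and C_frag is a (128, 128) accumulator fragment.
gemm(A_shared, B_shared, C_frag, transpose_B=True, clear_accum=False)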
@@ -4,6 +4,11 @@ from tvm.ir import make_node
@register_func("tvm.info.mem.local.var")
def mem_info_local_var():
"""Get memory information for local variable memory.
Returns:
Node: A MemoryInfo IR node describing the "local.var" memory scope
"""
return make_node(
"MemoryInfo",
unit_bits=8,
......
@@ -4,6 +4,18 @@ from tvm import tir
def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
"""Perform a reduction operation on a buffer along a specified dimension.
Args:
buffer (tir.Buffer): Input buffer to reduce
out (tir.Buffer): Output buffer to store results
reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum')
dim (int): Dimension along which to perform reduction
clear (bool): Whether to initialize the output buffer before reduction
Returns:
tir.Call: Handle to the reduction operation
"""
buffer = buffer.access_ptr("r")
out = out.access_ptr("w")
return tir.call_intrin(
@@ -38,12 +50,43 @@ def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
"""Perform reduce min on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
clear (bool, optional): If True, output buffer will be initialized to inf. Defaults to True.
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "min", dim, clear)
def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
"""Perform reduce sum on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "sum", dim, True)
def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
"""Perform reduce absolute sum on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "abssum", dim, True)