Unverified commit 5e529522 authored by Yichen Yan, committed by GitHub

[Lint] Add ruff config to check for useless spaces (#807)

* update lint config

* Remove spaces for blank line

* update
parent 4d54854b
@@ -152,7 +152,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
    def _process_dynamic_symbolic(self):
        """Extract information about dynamic shapes from the TIR function.

        Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
        for runtime shape resolution.
        """
@@ -179,17 +179,17 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
                *ins: List[torch.Tensor],
                stream: Optional[int] = None):
        """High-level wrapper for kernel execution.

        Handles:
        1. Input validation
        2. Output tensor allocation
        3. Dynamic shape resolution
        4. CUDA stream management

        Args:
            ins: Input PyTorch tensors
            stream: Optional CUDA stream for asynchronous execution

        Returns:
            Single tensor or list of tensors containing the kernel results
        """
...
@@ -17,7 +17,7 @@
# This file is modified from the original version,
# which is part of the flashinfer project
# (https://github.com/flashinfer-ai/flashinfer).
"""Library information. This is a standalone file that can be used to get various info.

Modified from flashinfer
"""
...
@@ -80,11 +80,11 @@ from .utils import index_to_coordinates  # noqa: F401
def symbolic(name: str, dtype: str = "int32"):
    """
    Create a TIR symbolic variable.

    Parameters:
        name (str): Identifier for the variable in generated TIR.
        dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".

    Returns:
        tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
    """
@@ -108,7 +108,7 @@ def annotate_layout(layout_map: Dict):
    Returns:
        block_attr: a block attribute

    Example:
        @T.prim_func
        def main(
@@ -149,7 +149,7 @@ def annotate_padding(padding_map: Dict):
    Returns:
        block_attr: a block attribute

    Example:
        @T.prim_func
        def main(
...
@@ -29,7 +29,7 @@ def create_list_of_mbarrier(*args: Any) -> Call:
    ------
    TypeError
        If the input is not a list or variadic arguments.

    Examples
    --------
    >>> create_list_of_mbarrier([128, 128])
...
@@ -20,18 +20,18 @@ _MEMORY_ORDER_ID_MAP = {
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
    """
    Create a tile memory-region descriptor for a BufferLoad.

    Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic
    (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.

    Parameters:
        buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices.
        access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access.
        *args (tir.PrimExpr): Extent expressions for each region dimension.

    Returns:
        tir.Call: A call to the `tl.region` intrinsic describing the memory region.

    Raises:
        KeyError: If access_type is not one of 'r', 'w', or 'rw'.
    """
@@ -83,15 +83,15 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
                                 extents: List[PrimExpr]):
    """
    Create a tl region descriptor for the given BufferRegion.

    Parameters:
        buffer_region (tir.BufferRegion): Source buffer region whose `region` items provide mins and extents.
        access_type (str): Access mode: "r", "w", or "rw".
        extents (List[PrimExpr]): Requested extents; must have length <= the number of extents in buffer_region.region.

    Returns:
        tir.Call: A tile-region descriptor (tl.region) covering the buffer_region.

    Raises:
        AssertionError: If the number of extents in buffer_region.region is smaller than len(extents).
    """
@@ -107,15 +107,15 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
    """
    Perform an atomic maximum on the value stored at dst, with an optional memory order.

    If memory_order is None, the runtime extern "AtomicMax" is called without an explicit
    memory-order id; otherwise the provided memory_order string is mapped to a numeric id
    using the module's memory-order map and passed to the extern.

    Parameters:
        dst (Buffer): Destination buffer/address to apply the atomic max.
        value (PrimExpr): Value to compare/store atomically.
        memory_order (str | None): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
            If provided, it is translated to the corresponding numeric memory-order id before the call.

    Returns:
        PrimExpr: A handle/expression representing the issued atomic maximum operation.
    """
@@ -129,14 +129,14 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
    """
    Atomically update the value at dst to the minimum of its current value and value.

    If memory_order is provided, it selects the memory-order semantic used by the underlying extern call;
    allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally
    to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument.

    Parameters:
        memory_order (str | None): Optional memory-order name controlling the atomic operation's ordering.

    Returns:
        PrimExpr: A handle expression representing the atomic-min operation.
    """
@@ -150,9 +150,9 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
    """
    Atomically add `value` into `dst`, returning a handle to the operation.

    Supports a scalar/addressed extern atomic add when neither argument exposes extents, or a
    tile-region-based atomic add for Buffer/BufferRegion/BufferLoad inputs. If both arguments are
    plain Buffers, their shapes must be structurally equal. If at least one side exposes extents,
    the extents are aligned (missing dimensions are treated as size 1); an assertion is raised if
    the extents cannot be deduced. The optional `memory_order` (one of "relaxed", "consume",
    "acquire", "release", "acq_rel", "seq_cst") is used only for the direct extern `AtomicAdd`
    path when no extents are available; the tile-region path ignores `memory_order`.

    Returns:
        PrimExpr: A handle representing the atomic addition operation.
    """
@@ -160,11 +160,11 @@ def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def get_extent(data):
    """
    Return the inferred extent (shape) of a buffer-like object.

    If `data` is a Var bound to a let value, the let value is resolved before inspection.

    Parameters:
        data: A Var, Buffer, or BufferRegion to inspect.

    Returns:
        The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents),
        or None if the extent cannot be determined.
    """
@@ -252,12 +252,12 @@ def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
    """Clamps the input value dst between [min_val, max_val]

    Args:
        dst: Input value to be clamped
        min_val: Minimum value
        max_val: Maximum value

    Returns:
        Value clamped to the specified range
    """
@@ -268,7 +268,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
    """Reshapes the input buffer to the specified shape.

    Args:
        src (Buffer): Input buffer to be reshaped
        shape (List[PrimExpr]): New shape for the buffer
@@ -284,7 +284,7 @@ def view(src: Buffer,
         dtype: Union[str, None] = None) -> Buffer:
    """
    Return a Tensor view of the input buffer with an optional new shape and dtype.

    If `shape` is None, the source buffer's shape is used; if `dtype` is None, the source buffer's
    dtype is used. The returned buffer shares the same underlying data as `src` (no copy).
    """
    if shape is None:
@@ -297,7 +297,7 @@ def view(src: Buffer,
def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
    """
    Load a value from the given buffer using the specified atomic memory ordering.

    Performs an atomic load from `src` and returns a PrimExpr representing the loaded value.
    memory_order selects the ordering and must be one of: "relaxed", "consume", "acquire",
    "release", "acq_rel", or "seq_cst" (default).
@@ -310,17 +310,17 @@ def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
    """
    Perform an atomic store of `src` into `dst` with the given memory ordering.

    Parameters:
        dst (Buffer): Destination buffer to store into.
        src (PrimExpr): Value to store.
        memory_order (str, optional): Memory ordering name; one of "relaxed", "consume",
            "acquire", "release", "acq_rel", or "seq_cst". Defaults to "seq_cst".
            The name is mapped to an internal numeric ID used by the underlying runtime.

    Returns:
        PrimExpr: A handle representing the issued atomic store operation.

    Raises:
        KeyError: If `memory_order` is not one of the supported names.
    """
...
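A hedged sketch of pairing the two helpers as a release/acquire flag; `flag` is a placeholder buffer, the element-access form is illustrative, and the `T.atomic_store`/`T.atomic_load` re-exports are assumed:

    T.atomic_store(flag[0], 1, memory_order="release")      # publish the result
    ready = T.atomic_load(flag[0], memory_order="acquire")  # later, observe the flag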
@@ -8,11 +8,11 @@ from tilelang.utils.language import get_buffer_region_from_load
def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
    """Fill a buffer or buffer region with a specified value.

    Args:
        buffer: Either a TVM buffer or buffer region to be filled
        value: The value to fill the buffer with

    Returns:
        A TVM intrinsic call that performs the fill operation
    """
@@ -23,13 +23,13 @@ def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
def clear(buffer: Union[tir.Buffer, tir.Var]):
    """Clear a buffer by filling it with zeros.

    Args:
        buffer: Either a TVM buffer or a variable that contains a buffer region

    Returns:
        A fill operation that sets the buffer contents to zero

    Raises:
        ValueError: If the buffer variable contains an invalid buffer region
    """
...
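A hedged sketch of `fill`/`clear` inside a kernel; the allocation helpers (`T.alloc_fragment`, `T.alloc_shared`) are assumed from the rest of the library and are not part of this diff:

    acc = T.alloc_fragment((64, 64), "float32")
    s_tile = T.alloc_shared((64, 64), "float16")
    T.clear(acc)           # fill the accumulator with zeros
    T.fill(s_tile, 1.0)    # fill the shared tile with a constant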
@@ -9,10 +9,10 @@ from tilelang.utils.language import get_buffer_elems
def any_of(buffer: Union[T.Tensor, BufferRegion]):
    """Check if any element in the buffer is true.

    Args:
        buffer: Either a TVM buffer or buffer region to be checked

    Returns:
        A TVM intrinsic call that performs the any operation
    """
def all_of(buffer: Union[T.Tensor, BufferRegion]): def all_of(buffer: Union[T.Tensor, BufferRegion]):
"""Check if all elements in the buffer are true. """Check if all elements in the buffer are true.
Args: Args:
buffer: Either a TVM buffer or buffer region to be checked buffer: Either a TVM buffer or buffer region to be checked
Returns: Returns:
A TVM intrinsic call that performs the any operation A TVM intrinsic call that performs the any operation
""" """
......
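A hedged fragment for the two predicates; `mask` and `valid` are placeholder buffers and the surrounding control flow is assumed:

    if T.any_of(mask):   # true if at least one element is non-zero
        ...              # placeholder branch body
    if T.all_of(valid):  # true only if every element is non-zero
        ...              # placeholder branch body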
@@ -14,10 +14,10 @@ from tilelang.language.utils import index_to_coordinates
def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
    """
    Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes.

    Parameters:
        var (tir.PrimExpr): The variable or expression to be printed.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.
    """
@@ -30,11 +30,11 @@ def print_var_with_condition(condition: tir.PrimExpr,
                             msg: str = "") -> tir.PrimExpr:
    """
    Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        var (tir.PrimExpr): The variable or expression to be printed.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True.
    """
@@ -67,12 +67,12 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
                                       msg: str = "") -> tir.PrimExpr:
    """
    Conditionally prints the values of a flattened TIR buffer if the condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        buffer (tir.Buffer): The buffer whose values need to be printed.
        elems (int): The number of elements in the buffer to print.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.
    """
@@ -91,12 +91,12 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
                                         msg: str = "") -> tir.PrimExpr:
    """
    Conditionally prints the values of a flattened TIR buffer if the condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        buffer (tir.Buffer): The buffer whose values need to be printed.
        elems (int): The number of elements in the buffer to print.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.
    """
@@ -116,12 +116,12 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
                                      msg: str = "") -> tir.PrimExpr:
    """
    Conditionally prints the values of a flattened TIR buffer if the condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        buffer (tir.Buffer): The buffer whose values need to be printed.
        elems (int): The number of elements in the buffer to print.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.
    """
@@ -136,20 +136,20 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
    """
    A generic print function that handles both TIR buffers and primitive expressions.

    - If the input is a TIR buffer, it prints its values, but only on the first thread (tx=0, ty=0, tz=0).
    - If the input is a TIR primitive expression, it prints its value directly.

    Parameters:
        obj (Any): The object to print. It can be either a tir.Buffer or tir.PrimExpr.
        msg (str): An optional message to include in the print statement.
        warp_group_id (int): The warp group id to print.
        warp_id (int): The warp id to print.
            The print thread will be warp_group_id * warp_group_size + warp_id.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.

    Raises:
        ValueError: If the input object type is unsupported.
    """
...
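A hedged debugging sketch; `acc` and `row_max` are placeholder buffers/expressions, and the `T.print` re-export is assumed:

    T.print(acc, msg="acc after mma")       # buffer: values printed from thread 0 only
    T.print(row_max[0], msg="row max")      # PrimExpr: value printed directly
    T.print(acc, msg="warp 1 view", warp_group_id=0, warp_id=1)  # pick the printing thread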
@@ -70,7 +70,7 @@ class BufferProxy:
class BaseTensorProxy:
    """Base proxy class for tensor types with configurable defaults.

    This class serves as a foundation for different tensor proxy types, providing
    customizable default values for scope, alignment, and offset factors. It implements
    the core functionality for creating TIR buffers with specific memory configurations.
@@ -137,7 +137,7 @@ class BaseTensorProxy:
class TensorProxy(BaseTensorProxy):
    """Main tensor proxy class for global scope buffers.

    This class implements the default tensor proxy with global memory scope;
    the tensor should be contiguous by default.
    """
@@ -186,7 +186,7 @@ class StridedTensorProxy(BaseTensorProxy):
class FragmentBufferProxy(BaseTensorProxy):
    """Proxy class for fragment memory buffers.

    This class represents tensor proxies specifically for local fragment memory,
    typically used in GPU tensor core operations.
    """
@@ -195,7 +195,7 @@ class FragmentBufferProxy(BaseTensorProxy):
class SharedBufferProxy(BaseTensorProxy):
    """Proxy class for shared memory buffers.

    This class represents tensor proxies for dynamic shared memory,
    commonly used in GPU shared memory operations.
    """
@@ -204,7 +204,7 @@ class SharedBufferProxy(BaseTensorProxy):
class LocalBufferProxy(BaseTensorProxy):
    """Proxy class for local memory buffers.

    This class represents tensor proxies for local memory scope,
    typically used for temporary computations in GPU kernels.
    """
...
@@ -94,8 +94,8 @@ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool =
        clear (bool, optional): If True, output buffer will be cleared before reduction.
            If False, results will be accumulated on existing values.
            Defaults to True.

    Note: When clear=True, reduce_sum will not compute directly on the output buffer. This is because
        during warp reduction, the same value would be accumulated multiple times (number of threads
        in the warp). Therefore, the implementation with clear=True follows these steps:
        1. create a temp buffer with same shape and dtype as out
        2. copy out to temp buffer
@@ -157,9 +157,9 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
    """
    Compute the cumulative sum of `src` along `dim`, writing results to `dst`.

    Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed
    in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When
    `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the
    `tl.cumsum` intrinsic.

    Returns:
        tir.Call: A handle to the emitted cumulative-sum operation.
    """
@@ -187,13 +187,13 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve
def finalize_reducer(reducer: tir.Buffer):
    """
    Finalize a reducer buffer by emitting the `tl.finalize_reducer` intrinsic.

    This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer.
    The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR.

    Parameters:
        reducer (tir.Buffer): Reducer buffer whose writable pointer will be finalized.

    Returns:
        tir.Call: Handle to the finalize reducer intrinsic call.
    """
...
@@ -5,13 +5,13 @@ from tvm.tir import PrimExpr
def index_to_coordinates(index, shape) -> list[PrimExpr]:
    """
    Convert a flat (linear) index into multi-dimensional coordinates for a given shape.

    Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates
    (one per dimension) such that converting those coordinates back to a linear index using the usual
    row-major / C-order formula yields the original index. The computation iterates from the last
    dimension to the first using modulo and integer division, then reverses the collected coordinates.

    Parameters:
        index (int or PrimExpr): The flat index to convert.
        shape (Sequence[int]): The extents of each dimension (length >= 1).

    Returns:
        list[PrimExpr]: Coordinates for each dimension in the same order as `shape`.
    """
@@ -27,26 +27,26 @@ def index_to_coordinates(index, shape) -> list[PrimExpr]:
def linear_index(*args: PrimExpr) -> PrimExpr:
    """
    Compute a flat (linear) index from multi-dimensional coordinates and strides.

    The function accepts a sequence of PrimExpr arguments where the first portion are coordinates
    and the trailing portion are the corresponding strides. The number of strides must equal
    (number of coordinates - 1). The linear index is computed as:

        linear = coords[0]
        for each (coord, stride) in zip(coords[1:], strides):
            linear = linear * stride + coord

    Examples:
        - linear_index(i) -> i
        - linear_index(i, j) -> i * j_stride + j (requires j_stride provided as stride when needed)
        - linear_index(i, j, stride_j) -> i * stride_j + j
        - linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k
        - linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v

    Raises:
        ValueError: If called with no arguments, or if the number of strides is not one less than
            the number of coordinates.

    Returns:
        PrimExpr: The computed linear index expression.
    """
...
@@ -4,4 +4,4 @@
from .layout import Layout  # noqa: F401
from .fragment import Fragment  # noqa: F401
from .swizzle import make_swizzled_layout  # noqa: F401
from .gemm_sp import make_metadata_layout  # noqa: F401
\ No newline at end of file
@@ -13,8 +13,8 @@ from typing import List
class Fragment(Layout):
    """
    A Fragment layout object that encapsulates iteration variables (forward_vars),
    thread iteration variables (forward_thread), and index transformations
    (forward_index). This class supports replication (thread_replicate) and
    index mapping for fine-grained control over multi-dimensional data layouts.
    """
@@ -49,7 +49,7 @@ class Fragment(Layout):
        used for multi-threading or replication in the hardware threads. Defaults to 1.
    forward_index_fn : callable, optional
        A function that takes iteration variables and returns an index or list
        of indices for this fragment. Used when `forward_fn` is None and
        the index transformation is derived separately.
    """
...
"""Library information. This is a standalone file that can be used to get various info. """Library information. This is a standalone file that can be used to get various info.
Modified from: https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/libinfo.py Modified from: https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/libinfo.py
""" """
......
@@ -20,7 +20,7 @@ from tilelang.profiler.bench import do_bench
@dataclass
class Profiler:
    """A profiler class for benchmarking and validating kernel implementations.

    Attributes:
        params: List of kernel parameters defining the input/output specifications
        result_idx: Indices indicating which parameters are output tensors
@@ -82,7 +82,7 @@ class Profiler:
        max_mismatched_ratio=0.01,
    ):
        """Validates kernel output against a reference implementation.

        Args:
            reference_program: Reference implementation to compare against
            input_tensors: Optional pre-generated input tensors
@@ -151,7 +151,7 @@ class Profiler:
        manual_check_prog: Callable = None,
    ):
        """Validates kernel output against a reference implementation.

        Args:
            reference_program: Reference implementation to compare against
            input_tensors: Optional pre-generated input tensors
@@ -177,7 +177,7 @@ class Profiler:
    def assert_consistent(self, repeat=10):
        """Checks for kernel consistency across multiple runs.

        Args:
            repeat: Number of times to repeat the consistency check
        """
@@ -202,11 +202,11 @@ class Profiler:
    def determine_profiler(self, func: Optional[Callable] = None):
        """Determines which profiler backend to use based on function type.

        Args:
            func: Function to be profiled
            profiler: Explicitly specified profiler type or "auto" for automatic detection

        Returns:
            str: The determined profiler type ("torch" or "tvm")
        """
@@ -225,7 +225,7 @@ class Profiler:
        input_tensors: List[torch.Tensor] = None,
    ) -> float:
        """Benchmarks the execution time of a given function.

        Args:
            func: Function to benchmark (uses adapter if None)
            warmup: Warmup time in milliseconds
@@ -234,7 +234,7 @@ class Profiler:
            n_repeat: Number of timing iterations
            profiler: Which profiling backend to use
            input_tensors: Optional pre-generated input tensors

        Returns:
            float: Average execution time in milliseconds
        """
...
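A hedged usage sketch of the methods documented above; obtaining the `Profiler` from a compiled kernel (here via a hypothetical `get_profiler()` accessor on a `kernel` built elsewhere) is an assumption, not shown in this diff:

    profiler = kernel.get_profiler()                                  # accessor assumed
    profiler.assert_allclose(ref_program, max_mismatched_ratio=0.01)  # validate vs. reference
    latency_ms = profiler.do_bench()                                  # average time in milliseconds
    profiler.assert_consistent(repeat=10)                             # determinism across repeated runs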
@@ -16,13 +16,13 @@ def do_bench(
    return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> Union[float, List[float]]:
    """Benchmarks the runtime of a PyTorch function.

    This function handles:
    - L2 cache flushing between runs for consistent timing
    - Automatic warmup and repeat count calculation
    - Optional gradient clearing for backward passes
    - Multiple measurement modes (mean, median, min, max)

    Args:
        fn: Function to benchmark
        warmup: Target warmup time in milliseconds
@@ -33,7 +33,7 @@ def do_bench(
        quantiles: Optional performance percentiles to compute
        fast_flush: Whether to use faster L2 cache flushing
        return_mode: How to aggregate timing results ("mean", "median", "min", "max")

    Returns:
        float: Aggregated runtime in milliseconds
    """
...
@@ -377,14 +377,14 @@ __device__ void decode_i4b_to_f16_scale_zeros_quantized_offset(T1 *_i4s, T2 *B_l
  T3 const scale_r = *(scale + scale_offset);
  uint const packed_scales_l = __pack_half2(scale_l, scale_l);
  uint const packed_scales_r = __pack_half2(scale_r, scale_r);

  const int num_elems_per_storage_dtype = sizeof(T1) * 8 / 4;
  T1 const qzeros_l = *qzeros;
  T1 const qzeros_r = *(qzeros + qzeros_offset);
  int16_t const zero_l = (qzeros_l >> (group_offset * 4) & 0xf);
  int16_t const zero_r = (qzeros_r >> (group_offset * 4) & 0xf);
  uint median_num_l = ((0xe400 | zero_l) << 16) | (0xe400 | zero_l);
  uint median_num_r = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);
...
@@ -17,7 +17,7 @@ __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, co
        "and.b32 %0, %13, 0b10000001110000001000000111000000;"
        "mul.bf16x2 %0, %0, %12;"
        "shl.b32 %1, %13, 3;"
        "and.b32 %1, %1, 0b10000001110000001000000111000000;"
        "mul.bf16x2 %1, %1, %12;"
        "shl.b32 %2, %13, 6;"
        "and.b32 %2, %2, 0b10000001110000001000000111000000;"
@@ -41,7 +41,7 @@ __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, co
      // Pay attention to the big-endianness issue
      B_local_decode[(i << 3) + j] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[1];
      B_local_decode[(i << 3) + j + 4] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[0];
    }
  }
  // Check if the synchronization is needed
}
@@ -57,25 +57,25 @@ def get_mxfp_intrin_group(
) -> Dict[str, str]:
    """
    Return metadata for an MXFP decoding intrinsic: function name and C source string.

    Validates the requested output dtype, source format, and storage dtype, then constructs
    a lookup key of the form `fp{source_bit}_to_{f16|bf16}` (appending `_twiddling` when
    use_twiddling is True) to select the corresponding C source snippet and a matching
    function name `decode_fp{source_bit}_to_{f16|bf16}` (also optionally suffixed with
    `_twiddling`).

    Parameters:
        out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16".
        source_format: Integer source representation; "int" or "uint".
        source_bit: Bit width of the packed source format (e.g., 4).
        storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8").
        use_twiddling: When True, select the twiddling variant of the decoding intrinsic.

    Returns:
        A dict with:
        - "func_name": the generated C function name string for the requested decode intrinsic.
        - "c_source": the C source string for that intrinsic.

    Raises:
        AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
        KeyError: if the constructed key does not match any available C source implementation.
@@ -31,10 +31,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
                          dtype: str):
    """
    Convert a packed 4-bit field stored in a uint8 into a bfloat16 value using an exponent scale.

    This function expects a storage field of width `nbit == 4` packed into the 8-bit input `val` and returns
    a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa.

    Behavior:
    - Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated).
    - Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`).
@@ -43,14 +43,14 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
      and clamps the result to the 8-bit exponent range (0..255).
    - Assembles a 16-bit bfloat16 bit pattern from (sign, biased-and-scaled-exponent, mantissa) and
      returns it reinterpreted as `bfloat16`.

    Parameters:
    - nbit: must be 4 (width of the packed field).
    - val: uint8 expression containing packed fields.
    - pos: index of the field within `val` (0-based); used to compute the bit shift.
    - scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression).
    - dtype: must be "bfloat16".

    Returns:
    - A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value.
    """
def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
""" """
Convert two float32 values to bfloat16 and pack them into a single uint32. Convert two float32 values to bfloat16 and pack them into a single uint32.
The two inputs v0 and v1 (float32 PrimExpr) are reinterpreted as uint32 bit patterns, optionally rounded to nearest-even The two inputs v0 and v1 (float32 PrimExpr) are reinterpreted as uint32 bit patterns, optionally rounded to nearest-even
by adding a rounding bias, then truncated to their upper 16 bits (bfloat16 representation). The two 16-bit results are by adding a rounding bias, then truncated to their upper 16 bits (bfloat16 representation). The two 16-bit results are
packed into a uint32 with v0 in the lower 16 bits and v1 in the upper 16 bits. packed into a uint32 with v0 in the lower 16 bits and v1 in the upper 16 bits.
Parameters: Parameters:
v0 (tir.PrimExpr): First float32 value to convert and pack. v0 (tir.PrimExpr): First float32 value to convert and pack.
v1 (tir.PrimExpr): Second float32 value to convert and pack. v1 (tir.PrimExpr): Second float32 value to convert and pack.
round_to_even (bool): If True, apply round-to-nearest-even bias before truncation (default True). round_to_even (bool): If True, apply round-to-nearest-even bias before truncation (default True).
Returns: Returns:
tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits). tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits).
""" """
......
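The packing arithmetic described above, mirrored in plain Python on concrete bit patterns; the round-to-nearest-even bias shown here is the conventional `0x7FFF + lsb` formula and is assumed to match the TIR implementation:

    import struct

    def f32x2_to_bf16x2_py(v0, v1, round_to_even=True):
        def to_bf16_bits(x):
            u = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
            if round_to_even:
                u += 0x7FFF + ((u >> 16) & 1)                 # nearest-even rounding bias (assumed form)
            return (u >> 16) & 0xFFFF                         # keep the upper 16 bits
        return to_bf16_bits(v0) | (to_bf16_bits(v1) << 16)    # v0 in low 16 bits, v1 in high 16 bits

    packed = f32x2_to_bf16x2_py(1.0, -2.0)
    assert packed & 0xFFFF == 0x3F80           # bf16(1.0)
    assert (packed >> 16) & 0xFFFF == 0xC000   # bf16(-2.0)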
@@ -76,7 +76,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
    Returns:
        _type_: _description_

    Example:
        qweight = torch.randint(0, 127, (10, 10), dtype=torch.int8).cuda()
        interleave_weight(qweight, 4, "float16")
...