Unverified Commit 5e529522 authored by Yichen Yan's avatar Yichen Yan Committed by GitHub
Browse files

[Lint] Add ruff config to check for useless spaces (#807)

* update lint config

* Remove spaces for blank line

* update
parent 4d54854b
......@@ -152,7 +152,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
def _process_dynamic_symbolic(self):
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
for runtime shape resolution.
"""
......@@ -179,17 +179,17 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
*ins: List[torch.Tensor],
stream: Optional[int] = None):
"""High-level wrapper for kernel execution.
Handles:
1. Input validation
2. Output tensor allocation
3. Dynamic shape resolution
4. CUDA stream management
Args:
ins: Input PyTorch tensors
stream: Optional CUDA stream for asynchronous execution
Returns:
Single tensor or list of tensors containing the kernel results
"""
......
......@@ -17,7 +17,7 @@
# This file is modified from the original version,
# which is part of the flashinfer project
# (https://github.com/flashinfer-ai/flashinfer).
"""Library information. This is a standalone file that can be used to get various info.
"""Library information. This is a standalone file that can be used to get various info.
Modified from flashinfer
"""
......
......@@ -80,11 +80,11 @@ from .utils import index_to_coordinates # noqa: F401
def symbolic(name: str, dtype: str = "int32"):
"""
Create a TIR symbolic variable.
Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
......@@ -108,7 +108,7 @@ def annotate_layout(layout_map: Dict):
Returns:
block_attr: a block attribute
Example:
@T.prim_func
def main(
......@@ -149,7 +149,7 @@ def annotate_padding(padding_map: Dict):
Returns:
block_attr: a block attribute
Example:
@T.prim_func
def main(
......
......@@ -29,7 +29,7 @@ def create_list_of_mbarrier(*args: Any) -> Call:
------
TypeError
If the input is not a list or variadic arguments.
Examples
--------
>>> create_list_of_mbarrier([128, 128])
......
......@@ -20,18 +20,18 @@ _MEMORY_ORDER_ID_MAP = {
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""
Create a tile memory-region descriptor for a BufferLoad.
Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic
(1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.
Parameters:
buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices.
access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access.
*args (tir.PrimExpr): Extent expressions for each region dimension.
Returns:
tir.Call: A call to the `tl.region` intrinsic describing the memory region.
Raises:
KeyError: If access_type is not one of 'r', 'w', or 'rw'.
"""
......@@ -83,15 +83,15 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
extents: List[PrimExpr]):
"""
Create a tl region descriptor for the given BufferRegion.
Parameters:
buffer_region (tir.BufferRegion): Source buffer region whose `region` items provide mins and extents.
access_type (str): Access mode: "r", "w", or "rw".
extents (List[PrimExpr]): Requested extents; must have length <= the number of extents in buffer_region.region.
Returns:
tir.Call: A tile-region descriptor (tl.region) covering the buffer_region.
Raises:
AssertionError: If the number of extents in buffer_region.region is smaller than len(extents).
"""
......@@ -107,15 +107,15 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
If memory_order is None the runtime extern "AtomicMax" is called without an explicit memory-order id; otherwise the provided memory_order string is mapped to a numeric id using the module's memory-order map and passed to the extern.
Parameters:
dst (Buffer): Destination buffer/address to apply the atomic max.
value (PrimExpr): Value to compare/store atomically.
memory_order (str | None): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
If provided, it is translated to the corresponding numeric memory-order id before the call.
Returns:
PrimExpr: A handle/expression representing the issued atomic maximum operation.
"""
......@@ -129,14 +129,14 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
"""
Atomically update the value at dst to the minimum of its current value and value.
If memory_order is provided, it selects the memory-order semantic used by the underlying extern call;
allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally
to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument.
Parameters:
memory_order (str | None): Optional memory-order name controlling the atomic operation's ordering.
Returns:
PrimExpr: A handle expression representing the atomic-min operation.
"""
......@@ -150,9 +150,9 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
Supports scalar/addressed extern atomic add when neither argument exposes extents, or tile-region-based atomic add for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicAdd` path when no extents are available — otherwise the tile-region path ignores `memory_order`.
Returns:
PrimExpr: A handle representing the atomic addition operation.
"""
......@@ -160,11 +160,11 @@ def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def get_extent(data):
"""
Return the inferred extent (shape) of a buffer-like object.
If `data` is a Var bound to a let value, the let value is resolved before inspection.
Parameters:
data: A Var, Buffer, or BufferRegion to inspect.
Returns:
The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined.
"""
......@@ -252,12 +252,12 @@ def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
"""Clamps the input value dst between [min_val, max_val]
Args:
dst: Input value to be clamped
min_val: Minimum value
max_val: Maximum value
Returns:
Value clamped to the specified range
"""
......@@ -268,7 +268,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
"""Reshapes the input buffer to the specified shape.
Args:
src (Buffer): Input buffer to be reshaped
shape (List[PrimExpr]): New shape for the buffer
......@@ -284,7 +284,7 @@ def view(src: Buffer,
dtype: Union[str, None] = None) -> Buffer:
"""
Return a Tensor view of the input buffer with an optional new shape and dtype.
If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy).
"""
if shape is None:
......@@ -297,7 +297,7 @@ def view(src: Buffer,
def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
"""
Load a value from the given buffer using the specified atomic memory ordering.
Performs an atomic load from `src` and returns a PrimExpr representing the loaded value.
memory_order selects the ordering and must be one of: "relaxed", "consume", "acquire",
"release", "acq_rel", or "seq_cst" (default).
......@@ -310,17 +310,17 @@ def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
"""
Perform an atomic store of `src` into `dst` with the given memory ordering.
Parameters:
dst (Buffer): Destination buffer to store into.
src (PrimExpr): Value to store.
memory_order (str, optional): Memory ordering name; one of "relaxed", "consume",
"acquire", "release", "acq_rel", or "seq_cst". Defaults to "seq_cst".
The name is mapped to an internal numeric ID used by the underlying runtime.
Returns:
PrimExpr: A handle representing the issued atomic store operation.
Raises:
KeyError: If `memory_order` is not one of the supported names.
"""
......
......@@ -8,11 +8,11 @@ from tilelang.utils.language import get_buffer_region_from_load
def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value.
Args:
buffer: Either a TVM buffer or buffer region to be filled
value: The value to fill the buffer with
Returns:
A TVM intrinsic call that performs the fill operation
"""
......@@ -23,13 +23,13 @@ def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
def clear(buffer: Union[tir.Buffer, tir.Var]):
"""Clear a buffer by filling it with zeros.
Args:
buffer: Either a TVM buffer or a variable that contains a buffer region
Returns:
A fill operation that sets the buffer contents to zero
Raises:
ValueError: If the buffer variable contains an invalid buffer region
"""
......
......@@ -9,10 +9,10 @@ from tilelang.utils.language import get_buffer_elems
def any_of(buffer: Union[T.Tensor, BufferRegion]):
"""Check if any element in the buffer is true.
Args:
buffer: Either a TVM buffer or buffer region to be checked
Returns:
A TVM intrinsic call that performs the any operation
"""
......@@ -44,10 +44,10 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]):
def all_of(buffer: Union[T.Tensor, BufferRegion]):
"""Check if all elements in the buffer are true.
Args:
buffer: Either a TVM buffer or buffer region to be checked
Returns:
A TVM intrinsic call that performs the any operation
"""
......
......@@ -14,10 +14,10 @@ from tilelang.language.utils import index_to_coordinates
def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
"""
Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes.
Parameters:
var (tir.PrimExpr): The variable or expression to be printed.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
"""
......@@ -30,11 +30,11 @@ def print_var_with_condition(condition: tir.PrimExpr,
msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
Parameters:
condition (tir.PrimExpr): A TIR expression representing the condition to check.
var (tir.PrimExpr): The variable or expression to be printed.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True.
"""
......@@ -67,12 +67,12 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Parameters:
condition (tir.PrimExpr): A TIR expression representing the condition to check.
buffer (tir.Buffer): The buffer whose values need to be printed.
elems (int): The number of elements in the buffer to print.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
"""
......@@ -91,12 +91,12 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Parameters:
condition (tir.PrimExpr): A TIR expression representing the condition to check.
buffer (tir.Buffer): The buffer whose values need to be printed.
elems (int): The number of elements in the buffer to print.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
"""
......@@ -116,12 +116,12 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Parameters:
condition (tir.PrimExpr): A TIR expression representing the condition to check.
buffer (tir.Buffer): The buffer whose values need to be printed.
elems (int): The number of elements in the buffer to print.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
"""
......@@ -136,20 +136,20 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
"""
A generic print function that handles both TIR buffers and primitive expressions.
- If the input is a TIR buffer, it prints its values, but only on the first thread (tx=0, ty=0, tz=0).
- If the input is a TIR primitive expression, it prints its value directly.
Parameters:
obj (Any): The object to print. It can be either a tir.Buffer or tir.PrimExpr.
msg (str): An optional message to include in the print statement.
warp_group_id (int): The warp group id to print.
warp_id (int): The warp id to print.
print thread will be warp_group_id * warp_group_size + warp_id.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
Raises:
ValueError: If the input object type is unsupported.
"""
......
......@@ -70,7 +70,7 @@ class BufferProxy:
class BaseTensorProxy:
"""Base proxy class for tensor types with configurable defaults.
This class serves as a foundation for different tensor proxy types, providing
customizable default values for scope, alignment, and offset factors. It implements
the core functionality for creating TIR buffers with specific memory configurations.
......@@ -137,7 +137,7 @@ class BaseTensorProxy:
class TensorProxy(BaseTensorProxy):
"""Main tensor proxy class for global scope buffers.
This class implements the default tensor proxy with global memory scope,
the tensor should be by default contiguous.
"""
......@@ -186,7 +186,7 @@ class StridedTensorProxy(BaseTensorProxy):
class FragmentBufferProxy(BaseTensorProxy):
"""Proxy class for fragment memory buffers.
This class represents tensor proxies specifically for local fragment memory,
typically used in GPU tensor core operations.
"""
......@@ -195,7 +195,7 @@ class FragmentBufferProxy(BaseTensorProxy):
class SharedBufferProxy(BaseTensorProxy):
"""Proxy class for shared memory buffers.
This class represents tensor proxies for dynamic shared memory,
commonly used in GPU shared memory operations.
"""
......@@ -204,7 +204,7 @@ class SharedBufferProxy(BaseTensorProxy):
class LocalBufferProxy(BaseTensorProxy):
"""Proxy class for local memory buffers.
This class represents tensor proxies for local memory scope,
typically used for temporary computations in GPU kernels.
"""
......
......@@ -94,8 +94,8 @@ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool =
clear (bool, optional): If True, output buffer will be cleared before reduction.
If False, results will be accumulated on existing values.
Defaults to True.
Note: When clear=True, reduce_sum will not compute directly on the output buffer. This is because
during warp reduction, the same value would be accumulated multiple times (number of threads
Note: When clear=True, reduce_sum will not compute directly on the output buffer. This is because
during warp reduction, the same value would be accumulated multiple times (number of threads
in the warp). Therefore, the implementation with clear=True follows these steps:
1. create a temp buffer with same shape and dtype as out
2. copy out to temp buffer
......@@ -157,9 +157,9 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
"""
Compute the cumulative sum of `src` along `dim`, writing results to `dst`.
Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.
Returns:
tir.Call: A handle to the emitted cumulative-sum operation.
"""
......@@ -187,13 +187,13 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve
def finalize_reducer(reducer: tir.Buffer):
"""
Finalize a reducer buffer by emitting the `tl.finalize_reducer` intrinsic.
This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer.
The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR.
Parameters:
reducer (tir.Buffer): Reducer buffer whose writable pointer will be finalized.
Returns:
tir.Call: Handle to the finalize reducer intrinsic call.
"""
......
......@@ -5,13 +5,13 @@ from tvm.tir import PrimExpr
def index_to_coordinates(index, shape) -> list[PrimExpr]:
"""
Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.
Parameters:
index (int or PrimExpr): The flat index to convert.
shape (Sequence[int]): The extents of each dimension (length >= 1).
Returns:
list[PrimExpr]: Coordinates for each dimension in the same order as `shape`.
"""
......@@ -27,26 +27,26 @@ def index_to_coordinates(index, shape) -> list[PrimExpr]:
def linear_index(*args: PrimExpr) -> PrimExpr:
"""
Compute a flat (linear) index from multi-dimensional coordinates and strides.
The function accepts a sequence of PrimExpr arguments where the first portion are coordinates
and the trailing portion are the corresponding strides. The number of strides must equal
(number of coordinates - 1). The linear index is computed as:
linear = coords[0]
for each (coord, stride) in zip(coords[1:], strides):
linear = linear * stride + coord
Examples:
- linear_index(i) -> i
- linear_index(i, j) -> i * j_stride + j (requires j_stride provided as stride when needed)
- linear_index(i, j, stride_j) -> i * stride_j + j
- linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k
- linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v
Raises:
ValueError: If called with no arguments, or if the number of strides is not one less than
the number of coordinates.
Returns:
PrimExpr: The computed linear index expression.
"""
......
......@@ -4,4 +4,4 @@
from .layout import Layout # noqa: F401
from .fragment import Fragment # noqa: F401
from .swizzle import make_swizzled_layout # noqa: F401
from .gemm_sp import make_metadata_layout # noqa: F401
\ No newline at end of file
from .gemm_sp import make_metadata_layout # noqa: F401
......@@ -13,8 +13,8 @@ from typing import List
class Fragment(Layout):
"""
A Fragment layout object that encapsulates iteration variables (forward_vars),
thread iteration variables (forward_thread), and index transformations
(forward_index). This class supports replication (thread_replicate) and
thread iteration variables (forward_thread), and index transformations
(forward_index). This class supports replication (thread_replicate) and
index mapping for fine-grained control over multi-dimensional data layouts.
"""
......@@ -49,7 +49,7 @@ class Fragment(Layout):
used for multi-threading or replication in the hardware threads. Defaults to 1.
forward_index_fn : callable, optional
A function that takes iteration variables and returns an index or list
of indices for this fragment. Used when `forward_fn` is None and
of indices for this fragment. Used when `forward_fn` is None and
the index transformation is derived separately.
"""
......
"""Library information. This is a standalone file that can be used to get various info.
"""Library information. This is a standalone file that can be used to get various info.
Modified from: https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/libinfo.py
"""
......
......@@ -20,7 +20,7 @@ from tilelang.profiler.bench import do_bench
@dataclass
class Profiler:
"""A profiler class for benchmarking and validating kernel implementations.
Attributes:
params: List of kernel parameters defining the input/output specifications
result_idx: Indices indicating which parameters are output tensors
......@@ -82,7 +82,7 @@ class Profiler:
max_mismatched_ratio=0.01,
):
"""Validates kernel output against a reference implementation.
Args:
reference_program: Reference implementation to compare against
input_tensors: Optional pre-generated input tensors
......@@ -151,7 +151,7 @@ class Profiler:
manual_check_prog: Callable = None,
):
"""Validates kernel output against a reference implementation.
Args:
reference_program: Reference implementation to compare against
input_tensors: Optional pre-generated input tensors
......@@ -177,7 +177,7 @@ class Profiler:
def assert_consistent(self, repeat=10):
"""Checks for kernel consistency across multiple runs.
Args:
repeat: Number of times to repeat the consistency check
"""
......@@ -202,11 +202,11 @@ class Profiler:
def determine_profiler(self, func: Optional[Callable] = None):
"""Determines which profiler backend to use based on function type.
Args:
func: Function to be profiled
profiler: Explicitly specified profiler type or "auto" for automatic detection
Returns:
str: The determined profiler type ("torch" or "tvm")
"""
......@@ -225,7 +225,7 @@ class Profiler:
input_tensors: List[torch.Tensor] = None,
) -> float:
"""Benchmarks the execution time of a given function.
Args:
func: Function to benchmark (uses adapter if None)
warmup: Warmup time in milliseconds
......@@ -234,7 +234,7 @@ class Profiler:
n_repeat: Number of timing iterations
profiler: Which profiling backend to use
input_tensors: Optional pre-generated input tensors
Returns:
float: Average execution time in milliseconds
"""
......
......@@ -16,13 +16,13 @@ def do_bench(
return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> Union[float, List[float]]:
"""Benchmarks the runtime of a PyTorch function.
This function handles:
- L2 cache flushing between runs for consistent timing
- Automatic warmup and repeat count calculation
- Optional gradient clearing for backward passes
- Multiple measurement modes (mean, median, min, max)
Args:
fn: Function to benchmark
warmup: Target warmup time in milliseconds
......@@ -33,7 +33,7 @@ def do_bench(
quantiles: Optional performance percentiles to compute
fast_flush: Whether to use faster L2 cache flushing
return_mode: How to aggregate timing results ("mean", "median", "min", "max")
Returns:
float: Aggregated runtime in milliseconds
"""
......
......@@ -377,14 +377,14 @@ __device__ void decode_i4b_to_f16_scale_zeros_quantized_offset(T1 *_i4s, T2 *B_l
T3 const scale_r = *(scale + scale_offset);
uint const packed_scales_l = __pack_half2(scale_l, scale_l);
uint const packed_scales_r = __pack_half2(scale_r, scale_r);
const int num_elems_per_storage_dtype = sizeof(T1) * 8 / 4;
T1 const qzeros_l = *qzeros;
T1 const qzeros_r = *(qzeros + qzeros_offset);
int16_t const zero_l = (qzeros_l >> (group_offset * 4) & 0xf);
int16_t const zero_r = (qzeros_r >> (group_offset * 4) & 0xf);
uint median_num_l = ((0xe400 | zero_l) << 16) | (0xe400 | zero_l);
uint median_num_r = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);
......
......@@ -17,7 +17,7 @@ __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, co
"and.b32 %0, %13, 0b10000001110000001000000111000000;"
"mul.bf16x2 %0, %0, %12;"
"shl.b32 %1, %13, 3;"
"and.b32 %1, %1, 0b10000001110000001000000111000000;"
"and.b32 %1, %1, 0b10000001110000001000000111000000;"
"mul.bf16x2 %1, %1, %12;"
"shl.b32 %2, %13, 6;"
"and.b32 %2, %2, 0b10000001110000001000000111000000;"
......@@ -41,7 +41,7 @@ __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, co
// Pay attention to the big-endianness issue
B_local_decode[(i << 3) + j] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[1];
B_local_decode[(i << 3) + j + 4] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[0];
}
}
}
// Check if the synchronization is needed
}
......@@ -57,25 +57,25 @@ def get_mxfp_intrin_group(
) -> Dict[str, str]:
"""
Return metadata for an MXFP decoding intrinsic: function name and C source string.
Validates the requested output dtype, source format, and storage dtype, then constructs
a lookup key of the form `fp{source_bit}_to_{f16|bf16}` (appending `_twiddling` when
use_twiddling is True) to select the corresponding C source snippet and a matching
function name `decode_fp{source_bit}_to_{f16|bf16}` (also optionally suffixed with
`_twiddling`).
Parameters:
out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16".
source_format: Integer source representation; "int" or "uint".
source_bit: Bit width of the packed source format (e.g., 4).
storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8").
use_twiddling: When True, select the twiddling variant of the decoding intrinsic.
Returns:
A dict with:
- "func_name": the generated C function name string for the requested decode intrinsic.
- "c_source": the C source string for that intrinsic.
Raises:
AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
KeyError: if the constructed key does not match any available C source implementation.
......
......@@ -31,10 +31,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
dtype: str):
"""
Convert a packed 4-bit field stored in a uint8 into a bfloat16 value using an exponent scale.
This function expects a storage field of width `nbit == 4` packed into the 8-bit input `val` and returns
a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa.
Behavior:
- Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated).
- Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`).
......@@ -43,14 +43,14 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
and clamps the result to the 8-bit exponent range (0..255).
- Assembles a 16-bit bfloat16 bit pattern from (sign, biased-and-scaled-exponent, mantissa) and
returns it reinterpreted as `bfloat16`.
Parameters:
- nbit: must be 4 (width of the packed field).
- val: uint8 expression containing packed fields.
- pos: index of the field within `val` (0-based); used to compute the bit shift.
- scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression).
- dtype: must be "bfloat16".
Returns:
- A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value.
"""
......@@ -75,16 +75,16 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
"""
Convert two float32 values to bfloat16 and pack them into a single uint32.
The two inputs v0 and v1 (float32 PrimExpr) are reinterpreted as uint32 bit patterns, optionally rounded to nearest-even
by adding a rounding bias, then truncated to their upper 16 bits (bfloat16 representation). The two 16-bit results are
packed into a uint32 with v0 in the lower 16 bits and v1 in the upper 16 bits.
Parameters:
v0 (tir.PrimExpr): First float32 value to convert and pack.
v1 (tir.PrimExpr): Second float32 value to convert and pack.
round_to_even (bool): If True, apply round-to-nearest-even bias before truncation (default True).
Returns:
tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits).
"""
......
......@@ -76,7 +76,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
Returns:
_type_: _description_
Example:
qweight = torch.randint(0, 127, (10, 10), dtype=torch.int8).cuda()
interleave_weight(qweight, 4, "float16")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment