"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "200a1086b4963c6b8e6098fe293716ca72af6af5"
Unverified commit e3a80b70 authored by coderabbitai[bot], committed by GitHub

📝 Add docstrings to `mxfp4` (#732)

* 📝 Add docstrings to `mxfp4`

Docstring generation was requested by @LeiWang1999.

* https://github.com/tile-ai/tilelang/pull/725#issuecomment-3191656561



The following files were modified:

* `examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py`
* `examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py`
* `examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py`
* `examples/dequantize_gemm/utils.py`
* `examples/gemm/example_gemm_autotune.py`
* `tilelang/intrinsics/utils.py`
* `tilelang/language/__init__.py`
* `tilelang/language/utils.py`
* `tilelang/quantize/mxfp.py`
* `tilelang/quantize/quantization.py`

* [Lint] More accurate docstring

* [Lint]

---------
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
parent 24603e4a
...@@ -82,6 +82,39 @@ def bitnet_158_int8xint2_prefill(
warp_col_tiles=32,
chunk=64,
):
"""
Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
The returned prim_func expects:
- A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
- B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte).
- C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
Details:
- Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
- Tiling parameters:
- block_row_warps, block_col_warps: number of warps per block in row/col.
- warp_row_tiles, warp_col_tiles: tiles per warp.
- chunk: K-sized chunk per block (block_K).
- micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32").
- Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior.
- Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values.
Parameters:
M, N, K (int): Global matrix dimensions.
in_dtype (str): Input and decoded B element dtype; "float16" or "int8".
out_dtype (str): Output C dtype; one of "float16", "float32", "int32".
accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32").
fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used).
block_row_warps (int): Warps in block row dimension.
block_col_warps (int): Warps in block column dimension.
warp_row_tiles (int): Tiles per warp in row dimension.
warp_col_tiles (int): Tiles per warp in column dimension.
chunk (int): K-length per block (block_K).
Returns:
T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution.
"""
assert in_dtype in [
"float16",
"int8",
...@@ -152,6 +185,23 @@ def bitnet_158_int8xint2_prefill(
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
"""
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
"""
with T.Kernel(
T.ceildiv(N, block_N),
T.ceildiv(M, block_M),
...
...@@ -8,6 +8,21 @@ from utils import torch_convert_bit_twiddling, torch_convert
def get_configs():
"""
Return a list of tuning configuration dictionaries for the autotuned matmul kernel.
Each dictionary is a single combination (Cartesian product) of the following parameters:
- block_M: tile size for M dimension (one of 64, 128, 256)
- block_N: tile size for N dimension (one of 64, 128, 256)
- block_K: tile size for K dimension
- num_stages: pipeline stages for K-loop (0 or 2)
- threads: number of threads to launch (128, 256, or 512)
- split: K-splitting factor (1 or 2)
Returns:
list[dict]: List of configuration dicts usable by the autotuner, where each dict maps
the parameter name to its chosen value.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
...@@ -45,6 +60,35 @@ def matmul(M,
num_stages=2,
threads=256,
split=1):
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
...@@ -60,6 +104,9 @@
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
# It requires that 2 consecutive uint8 elements (16 bits) contain 4 fp4 elements in a bit-twiddled layout.
# The layout is shown below: the pair (x,y) means the bit at this position is the y-th bit of the x-th fp4 element.
# (0,0)(3,0)(3,3)(1,0)(3,1)(3,2)(2,0)(0,1)(0,2)(0,3)(1,1)(1,2)(1,3)(2,1)(2,2)(2,3)
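For illustration only, a pure-Python sketch that gathers the four fp4 nibbles back out of one 16-bit group using the (element, bit) table above, assuming position 0 is the least-significant bit of the word; the actual kernel performs this inside the external twiddling intrinsic with a few shifts and masks.

```python
# (element, bit) for each of the 16 bit positions, taken from the comment above.
LAYOUT = [(0, 0), (3, 0), (3, 3), (1, 0), (3, 1), (3, 2), (2, 0), (0, 1),
          (0, 2), (0, 3), (1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]

def untwiddle(word16: int) -> list[int]:
    """Gather the four fp4 nibbles encoded in one 16-bit group."""
    vals = [0, 0, 0, 0]
    for pos, (elem, bit) in enumerate(LAYOUT):
        vals[elem] |= ((word16 >> pos) & 1) << bit
    return vals

print(untwiddle(0b0000000000000001))  # element 0 gets its bit 0 set -> [1, 0, 0, 0]
```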
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
...@@ -75,6 +122,20 @@
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
This function validates the requested input/output datatypes and returns a TileLang `@T.macro` named `fast_dequant_bf16_fp4_twiddling` which:
- Loads compressed FP4 bytes from a shared buffer into per-thread local registers (vectorized loads).
- Invokes an external dequantization routine (via `T.call_extern`) to expand the packed FP4 values into BF16 in registers.
- Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
Notes and preconditions:
- Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
- The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
- The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly.
- The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout.
"""
assert in_dtype in ["fp4"] assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"] assert out_dtype in ["bfloat16"]
...@@ -86,6 +147,23 @@ def matmul(M, ...@@ -86,6 +147,23 @@ def matmul(M,
@T.macro @T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared): def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
# import fast_dequantize plugin # import fast_dequantize plugin
"""
Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer.
This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters.
Parameters:
B_shared: Shared-memory buffer containing packed quantized values (packed FP4 layout).
B_dequantize_shared: Shared-memory buffer to receive dequantized BF16 values (written in-place by this routine).
Side effects:
- Imports the external dequantization plugin via `import_source` and invokes `func_name`.
- Writes dequantized BF16 results into `B_dequantize_shared`.
Notes:
- This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`).
- No value is returned; results are produced by mutation of `B_dequantize_shared`.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
...@@ -117,11 +195,51 @@
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.
The returned macro (named `simple_dequant_bf16_fp4`) expects B_shared and B_dequantize_shared buffers (shapes and a few loop/constant names like
`B_shared_shape`, `B_dequantize_shared_shape`, `storage_dtype`, `out_dtype`, `num_bits`, `num_elems_per_byte`, `block_N`, and `block_K`) to be available in the surrounding TIR scope. It:
- Unpacks 4-bit FP values from the packed uint8 representation in B_shared.
- Converts each 4-bit value to a bfloat16 element using an internal helper `_tir_u8_to_f4_to_bf16`.
- Writes the dequantized bfloat16 block into B_dequantize_shared.
Constraints:
- Supports only in_dtype="fp4" and out_dtype="bfloat16".
- The helper assumes nbit == 4 and produces bfloat16 values.
- The macro uses a fixed test-scale of 0 (no per-element scaling) as written.
Returns:
A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
"""
assert in_dtype in ["fp4"] assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"] assert out_dtype in ["bfloat16"]
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
scale: tir.PrimExpr, dtype: str): scale: tir.PrimExpr, dtype: str):
"""
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters:
nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16".
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
...@@ -142,6 +260,21 @@
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared):
"""
Dequantize a packed FP4 uint8 shared buffer into BF16 and store the result into a shared dequantized buffer.
This helper:
- Loads B_shared into a local fragment, converts each packed FP4 element to BF16 using `_tir_u8_to_f4_to_bf16`, and writes the dequantized values into B_dequantize_shared.
- Iterates in parallel over the logical block columns (block_N) and block_K, unpacking elements from bytes using `num_elems_per_byte`.
- Uses a fixed scale of 0 in the conversion (placeholder for testing); `num_bits` and `num_elems_per_byte` are expected to be available from the enclosing scope.
Parameters:
B_shared: shared-memory buffer containing packed FP4 data (uint8-packed).
B_dequantize_shared: shared-memory buffer to receive BF16 dequantized values.
Side effects:
Writes dequantized BF16 values into B_dequantize_shared. No return value.
"""
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
T.copy(B_shared, B_local)
...@@ -163,6 +296,29 @@
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
No value is returned.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
...@@ -194,6 +350,19 @@
def ref_program_twiddling(A, qB):
"""
Compute reference BF16 matrix multiply using bit-twiddled FP4 quantized B.
Converts qB (a bit-twiddled, packed FP4 representation of matrix B) back to floating,
performs C = A @ B^T in full precision, and returns the result converted to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K). Treated as floating-point (converted to torch.float for compute).
qB (torch.Tensor): Bit-twiddled, packed FP4 representation of B (quantized). Shape corresponds to B's packed layout.
Returns:
torch.Tensor: Result matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB) B = torch_convert_bit_twiddling(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
...@@ -202,6 +371,18 @@ def ref_program_twiddling(A, qB): ...@@ -202,6 +371,18 @@ def ref_program_twiddling(A, qB):
def ref_program_simple(A, qB): def ref_program_simple(A, qB):
"""
Compute a reference BF16 matrix multiply using a simple (non-twiddled) dequantization of qB.
Converts the quantized tensor `qB` to full-precision values via `torch_convert`, computes C = A @ B^T in float32, and casts the result to bfloat16 before returning.
Parameters:
A (torch.Tensor): Left input matrix with shape (M, K).
qB (torch.Tensor): Quantized representation of the right matrix; expected to be compatible with `torch_convert` and represent a matrix whose transpose will be multiplied by A.
Returns:
torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
"""
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert(qB) B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
...@@ -210,6 +391,22 @@ def ref_program_simple(A, qB): ...@@ -210,6 +391,22 @@ def ref_program_simple(A, qB):
def main(m=256, n=256, k=256, fast_dequant=True, tune=False): def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
"""
Run and benchmark the tiled, optionally autotuned FP4->BF16 GEMM kernel and validate results against a PyTorch reference.
This function builds a matmul kernel (either with autotuning or fixed tiling), obtains a profiler, validates numerical correctness against the appropriate reference implementation (bit-twiddled fast dequantization or simple dequantization), and runs a benchmark that prints measured latency (ms) and effective TFLOPs.
Parameters:
m (int): Number of rows of A and output C (default 256).
n (int): Number of columns of B and output C (default 256).
k (int): Inner dimension (columns of A, rows of B) (default 256).
fast_dequant (bool): If True use the fast twiddling dequantization path and validate against the twiddling reference; otherwise use the simple dequant path (default True).
tune (bool): If True build the kernel with autotuning configurations; if False use a fixed tiling and threading configuration for reproducible benchmarking (default False).
Side effects:
- Prints latency and TFLOPs to stdout.
- Raises an assertion via the profiler if the kernel's outputs do not match the chosen reference within the tolerances (rtol=0.01, atol=0.01).
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
...
...@@ -9,6 +9,27 @@ from utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
"""
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
...@@ -29,6 +50,20 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
def get_configs():
"""
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
- block_M, block_N, block_K: tiling sizes (64, 128, 256)
- num_stages: pipeline stages (0, 2)
- threads: thread counts (128, 256, 512)
- split: K-splitting factor (1, 2)
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
...@@ -61,7 +96,43 @@ def matmul(M,
num_stages=2,
threads=256,
split=1):
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
...@@ -90,6 +161,20 @@
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"] assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"] assert out_dtype in ["bfloat16"]
...@@ -101,6 +186,30 @@ def matmul(M, ...@@ -101,6 +186,30 @@ def matmul(M,
@T.macro @T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
# import fast_dequantize plugin # import fast_dequantize plugin
"""
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry, interprets it as an exponent bias
(applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor.
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Parameters:
- B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout).
- B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values.
- Scale: per-block scale tensor; entries are interpreted such that the multiplicative scale
= 2^(Scale - 127).
- k: block index along the K dimension used to select the appropriate Scale entries.
Side effects:
- Mutates B_dequantize_shared in shared memory.
- Calls an external intrinsic function (must be provided by the environment via `import_source`
and `func_name`) to perform the low-level unpacking/dequantization.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
...@@ -146,11 +255,38 @@
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"] assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"] assert out_dtype in ["bfloat16"]
@T.macro @T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
"""
Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents.
Per-element behavior:
- Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte).
- Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16.
- Writes the dequantized BF16 block into B_dequantize_shared.
Parameters:
- B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout).
- B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results.
- Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element.
- k: current block index along the K dimension (used to select the appropriate slice of Scale).
Side effects:
- Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
"""
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
...@@ -177,6 +313,17 @@
Scale: T.Tensor(Scale_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
...@@ -210,6 +357,19 @@
def ref_program_twiddling(A, qB, Scale):
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB) B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]): for i in range(B.shape[0]):
...@@ -221,6 +381,21 @@ def ref_program_twiddling(A, qB, Scale): ...@@ -221,6 +381,21 @@ def ref_program_twiddling(A, qB, Scale):
def ref_program_simple(A, qB, Scale): def ref_program_simple(A, qB, Scale):
"""
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert(qB) B = torch_convert(qB)
for i in range(B.shape[0]): for i in range(B.shape[0]):
...@@ -232,6 +407,22 @@ def ref_program_simple(A, qB, Scale): ...@@ -232,6 +407,22 @@ def ref_program_simple(A, qB, Scale):
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False): def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
"""
Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.
Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS.
Parameters:
m (int): Number of rows of A / output rows. Default 256.
n (int): Number of columns of B / output columns. Default 256.
k (int): Reduction dimension. Default 256.
scale_size (int): Size of the per-block scale vector used for dequantization. Default 32.
fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True.
tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False.
Returns:
None
"""
total_flops = 2 * m * n * k
if tune:
...
...@@ -2,6 +2,20 @@ import torch
def torch_convert_bit_twiddling(tensor):
"""
Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
Parameters:
tensor (torch.Tensor): 2-D input tensor with dtype `torch.uint8`. Shape (N, K).
Returns:
torch.Tensor: New tensor of dtype `torch.bfloat16` with shape (N, K*2), where each input column pair produces two bf16 output columns.
Raises:
AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
"""
def _convert(val0, val1, pos) -> torch.bfloat16:
assert val0.dtype == torch.uint8
...@@ -37,6 +51,19 @@ def torch_convert_bit_twiddling(tensor):
def torch_convert(tensor, scale_size=None, Scale=None):
"""
Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding.
Each input byte holds two 4-bit encoded values (low and high nibble). For each nibble this function derives sign/scale bits, a 3-bit exponent fragment and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input.
Parameters:
tensor (torch.Tensor): 2D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values.
scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed. When supplied, per-output-element scaling is applied to the exponent using Scale.
Scale (torch.Tensor, optional): A 2D tensor used to supply per-element integer scale adjustments to the exponent. If scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size].
Returns:
torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
"""
def _convert(val, pos, scale=None):
assert val.dtype == torch.uint8
...@@ -67,6 +94,15 @@ def torch_convert(tensor, scale_size=None, Scale=None):
def print_bit(name, val):
"""
Print the 32-bit binary representation of a CPU scalar extracted from a PyTorch tensor.
Converts `val` to CPU, reads its Python scalar with `.item()`, formats it as a 32-bit binary string, and prints it prefixed by `name`.
Parameters:
name (str): Label printed before the binary representation.
val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown.
"""
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
...@@ -11,10 +11,40 @@ import torch
def ref_program(A, B):
"""
Compute the matrix product of A and the transpose of B.
A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes.
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False, topk=20):
"""
Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply.
When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended
configurations (device-specific TensorCore-friendly tilings). Each returned dict contains:
- block_M, block_N, block_K: tile sizes
- num_stages: pipeline staging (0 means no explicit staging)
- thread_num: total threads used for the block
- enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling)
When with_roller is False this returns the Cartesian product of a fixed set of candidate
parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag.
Parameters:
M, N, K (int): GEMM dimensions used to generate valid tile sizes.
with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints;
otherwise use a predefined candidate grid.
topk (int): Maximum number of roller hints to request when with_roller is True.
Returns:
List[dict]: A list of configuration dictionaries as described above.
Raises:
ValueError: if with_roller is True but the roller returns no hints.
"""
if with_roller:
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
carve_template = MatmulTemplate(
...
...@@ -76,6 +76,19 @@ def mfma_store_index_map(thread_id, local_id):
def get_mma_micro_size(dtype: Literal["float16", "int8"]):
# TODO(lei): FP8 related precision support.
# Basic Tensor Core Matrix Multiply operation Unit
"""
Return the MMA (Tensor Core) micro-tile dimensions for a given data type.
This function returns the micro tile sizes (x, y, k) used by MMA/Tensor Core operations.
- x: tile width in the output/result dimension
- y: tile height in the output/result dimension
- k: tile depth in the reduction/K dimension
Accepted dtype strings include "float16", "int8" and some FP8 identifiers ("float8_e4m3", "float8_e5m2"). For FP8 and int8 types the reduction depth (`k`) is 32; for float16 it is 16.
Returns:
tuple[int, int, int]: (micro_size_x, micro_size_y, micro_size_k)
"""
micro_size_x = micro_size_y = 16
micro_size_k = 16
if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:
...
...@@ -17,7 +17,6 @@ from .proxy import (
make_tensor, # noqa: F401
Buffer, # noqa: F401
Tensor, # noqa: F401
StridedTensor, # noqa: F401
FragmentBuffer, # noqa: F401
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
...@@ -73,6 +72,16 @@ from .utils import index_to_coordinates # noqa: F401
def symbolic(name: str, dtype: str = "int32"):
"""
Create a TIR symbolic variable.
Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
return tir.Var(name, dtype)
...
...@@ -4,24 +4,16 @@ from tvm.tir import PrimExpr
def index_to_coordinates(index, shape) -> list[PrimExpr]:
""" """
Convert a flat (linear) index to multi-dimensional coordinates for a given shape. Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
Example: Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.
shape = (4, 5, 6)
index = 53 Parameters:
index_to_coordinates(53, (4, 5, 6)) -> [1, 3, 5] index (int or PrimExpr): The flat index to convert.
# Explanation: shape (Sequence[int]): The extents of each dimension (length >= 1).
# 53 // (5*6) = 1 (1st coordinate)
# 53 % (5*6) = 23
# 23 // 6 = 3 (2nd coordinate)
# 23 % 6 = 5 (3rd coordinate)
Args:
index (int): The flat index to convert.
shape (tuple or list of int): The shape of the multi-dimensional array.
Returns: Returns:
list: A list of coordinates corresponding to each dimension. list[PrimExpr]: Coordinates for each dimension in the same order as `shape`.
""" """
coordinates = []
dims = len(shape)
...@@ -34,18 +26,29 @@ def index_to_coordinates(index, shape) -> list[PrimExpr]:
def linear_index(*args: PrimExpr) -> PrimExpr:
"""
Convert a list of coordinates to a flat (linear) index using strides.
Usage examples:
linear_index(i) -> i
linear_index(i, j) -> i * stride + j
linear_index(i, j, stride_j) -> i * stride_j + j
linear_index(i, j, k, stride_j, stride_k)
-> i * stride_j * stride_k + j * stride_k + k
Example for index = i * threads * local_size + tx * local_size + v:
Suppose you have i, tx, v as coordinates, and threads, local_size as strides:
linear_index(i, tx, v, threads, local_size) == i * threads * local_size + tx * local_size + v
Compute a flat (linear) index from multi-dimensional coordinates and strides.
The function accepts a sequence of PrimExpr arguments where the first portion are coordinates
and the trailing portion are the corresponding strides. The number of strides must equal
(number of coordinates - 1). The linear index is computed as:
linear = coords[0]
for each (coord, stride) in zip(coords[1:], strides):
linear = linear * stride + coord
Examples:
- linear_index(i) -> i
- linear_index(i, j) -> i * j_stride + j (requires j_stride provided as stride when needed)
- linear_index(i, j, stride_j) -> i * stride_j + j
- linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k
- linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v
Raises:
ValueError: If called with no arguments, or if the number of strides is not one less than
the number of coordinates.
Returns:
PrimExpr: The computed linear index expression.
"""
n = len(args)
if n == 0:
...
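A plain-Python mirror of the two helpers' documented behavior (integer inputs only, not the TIR implementation), reproducing the 53 -> [1, 3, 5] example from the old docstring and the inverse mapping:

```python
def index_to_coordinates(index, shape):
    # Row-major: peel off the last dimension first, then reverse.
    coords = []
    for dim in reversed(shape):
        coords.append(index % dim)
        index //= dim
    return list(reversed(coords))

def linear_index(*args):
    # args = coords..., strides...; the number of strides is len(coords) - 1.
    n = len(args)
    coords, strides = args[:(n + 1) // 2], args[(n + 1) // 2:]
    linear = coords[0]
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord
    return linear

print(index_to_coordinates(53, (4, 5, 6)))  # [1, 3, 5]
print(linear_index(1, 3, 5, 5, 6))          # 53
```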
...@@ -56,9 +56,29 @@ def get_mxfp_intrin_group(
use_twiddling: bool = False,
) -> Dict[str, str]:
""" """
This function is used to get the intrinsic group of the MXFP operation to avoid the overhead of fast decoding. Return metadata for an MXFP decoding intrinsic: function name and C source string.
MXFP is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
intrinsic operations that can be performed on these inputs. This function retrieves and returns this group. Validates the requested output dtype, source format, and storage dtype, then constructs
a lookup key of the form `fp{source_bit}_to_{f16|bf16}` (appending `_twiddling` when
use_twiddling is True) to select the corresponding C source snippet and a matching
function name `decode_fp{source_bit}_to_{f16|bf16}` (also optionally suffixed with
`_twiddling`).
Parameters:
out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16".
source_format: Integer source representation; "int" or "uint".
source_bit: Bit width of the packed source format (e.g., 4).
storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8").
use_twiddling: When True, select the twiddling variant of the decoding intrinsic.
Returns:
A dict with:
- "func_name": the generated C function name string for the requested decode intrinsic.
- "c_source": the C source string for that intrinsic.
Raises:
AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
KeyError: if the constructed key does not match any available C source implementation.
""" """
assert out_dtype in ["float16", "bfloat16" assert out_dtype in ["float16", "bfloat16"
], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." ], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
......
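A hedged usage sketch of the lookup described above; the keyword names follow the docstring, and the expected function name is what the documented `decode_fp{source_bit}_to_{f16|bf16}` + `_twiddling` naming scheme would produce.

```python
from tilelang.quantize import get_mxfp_intrin_group

info = get_mxfp_intrin_group(
    out_dtype="bfloat16",
    source_format="uint",
    source_bit=4,
    storage_dtype="uint8",
    use_twiddling=True,
)
# If the documented naming scheme holds, this selects the twiddling bf16 decoder.
print(info["func_name"])          # expected: decode_fp4_to_bf16_twiddling
print(len(info["c_source"]) > 0)  # the C source snippet backing the intrinsic
```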
...@@ -29,6 +29,31 @@ from tvm import tir
# fmt: off
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
"""
Convert a packed 4-bit field stored in a uint8 into a bfloat16 value using an exponent scale.
This function expects a storage field of width `nbit == 4` packed into the 8-bit input `val` and returns
a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa.
Behavior:
- Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated).
- Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`).
- Interprets the 4-bit field as: sign = bit3, exponent = bits1-2, mantissa = bit0.
- Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent,
and clamps the result to the 8-bit exponent range (0..255).
- Assembles a 16-bit bfloat16 bit pattern from (sign, biased-and-scaled-exponent, mantissa) and
returns it reinterpreted as `bfloat16`.
Parameters:
- nbit: must be 4 (width of the packed field).
- val: uint8 expression containing packed fields.
- pos: index of the field within `val` (0-based); used to compute the bit shift.
- scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression).
- dtype: must be "bfloat16".
Returns:
- A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
...@@ -48,6 +73,21 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
return val_bf16
def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
"""
Convert two float32 values to bfloat16 and pack them into a single uint32.
The two inputs v0 and v1 (float32 PrimExpr) are reinterpreted as uint32 bit patterns, optionally rounded to nearest-even
by adding a rounding bias, then truncated to their upper 16 bits (bfloat16 representation). The two 16-bit results are
packed into a uint32 with v0 in the lower 16 bits and v1 in the upper 16 bits.
Parameters:
v0 (tir.PrimExpr): First float32 value to convert and pack.
v1 (tir.PrimExpr): Second float32 value to convert and pack.
round_to_even (bool): If True, apply round-to-nearest-even bias before truncation (default True).
Returns:
tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits).
"""
mask = tir.const((1 << 16) - 1, "uint32") mask = tir.const((1 << 16) - 1, "uint32")
res = [] res = []
for data in [v0, v1]: for data in [v0, v1]:
......
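A NumPy sketch of the packing this docstring describes, including the round-to-nearest-even bias; offered as an illustration of the bit layout rather than a drop-in for the TIR helper.

```python
import numpy as np

def f32x2_to_bf16x2_u32(v0: float, v1: float, round_to_even: bool = True) -> int:
    packed = 0
    for i, v in enumerate((v0, v1)):                 # v0 -> low 16 bits, v1 -> high 16 bits
        bits = np.float32(v).view(np.uint32).item()
        if round_to_even:
            bits += ((bits >> 16) & 1) + 0x7FFF      # round-to-nearest-even bias
        packed |= ((bits >> 16) & 0xFFFF) << (16 * i)
    return packed

print(hex(f32x2_to_bf16x2_u32(1.0, -2.0)))           # 0xc0003f80
```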