Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low-precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
The returned prim_func expects:
- A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
- B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte).
- C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
Details:
- Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
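For intuition about the packed layout, here is an assumed reference packer (illustrative only; the kernel's actual interleaving is defined by the decode_i2u_to_i8s / decode_i2s_to_i8s intrinsics):

import numpy as np

def pack_2bit_to_int8(B_elems: np.ndarray) -> np.ndarray:
    # Pack an (N, K) array of 2-bit values (0..3) into (N, K // 4) int8 storage,
    # four elements per byte along K.
    N, K = B_elems.shape
    assert K % 4 == 0
    packed = np.zeros((N, K // 4), dtype=np.uint8)
    for j in range(4):
        packed |= (B_elems[:, j::4].astype(np.uint8) & 0b11) << (2 * j)
    return packed.view(np.int8)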
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C.
This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).
Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
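A minimal sketch of the block/pipeline structure, assuming the TileLang T.* primitives used throughout these examples (the real kernel decodes B per-thread and drives INT4TensorCoreIntrinEmitter rather than the plain T.gemm shown here):

import tilelang.language as T

def build_kernel_sketch(M, N, K, block_M, block_N, block_K, threads, num_stages,
                        in_dtype="int8", storage_dtype="int8", accum_dtype="int32"):
    @T.prim_func
    def main(A: T.Tensor((M, K), in_dtype),
             B: T.Tensor((N, K // 4), storage_dtype),
             C: T.Tensor((M, N), accum_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), in_dtype)
            B_shared = T.alloc_shared((block_N, block_K // 4), storage_dtype)
            B_dequantize_shared = T.alloc_shared((block_N, block_K), "int8")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * (block_K // 4)], B_shared)
                # ... per-thread decode of B_shared into B_dequantize_shared goes here ...
                T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])
    return main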
...
@@ -10,15 +10,15 @@ from utils import torch_convert_bit_twiddling, torch_convert
def get_configs():
"""
"""
Return a list of tuning configuration dictionaries for the autotuned matmul kernel.
Return a list of tuning configuration dictionaries for the autotuned matmul kernel.
Each dictionary is a single combination (Cartesian product) of the following parameters:
Each dictionary is a single combination (Cartesian product) of the following parameters:
- block_M: tile size for M dimension (one of 64, 128, 256)
- block_M: tile size for M dimension (one of 64, 128, 256)
- block_N: tile size for N dimension (one of 64, 128, 256)
- block_N: tile size for N dimension (one of 64, 128, 256)
- block_K: tile size for K dimension
- block_K: tile size for K dimension
- num_stages: pipeline stages for K-loop (0 or 2)
- num_stages: pipeline stages for K-loop (0 or 2)
- threads: number of threads to launch (128, 256, or 512)
- threads: number of threads to launch (128, 256, or 512)
- split: K-splitting factor (1 or 2)
- split: K-splitting factor (1 or 2)
Returns:
Returns:
list[dict]: List of configuration dicts usable by the autotuner, where each dict maps
list[dict]: List of configuration dicts usable by the autotuner, where each dict maps
the parameter name to its chosen value.
the parameter name to its chosen value.
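A plausible implementation of this Cartesian product (the block_K candidates are not listed above, so the values below are placeholders):

import itertools

def get_configs():
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[64, 128, 256],
        block_K=[64, 128],          # placeholder candidates
        num_stages=[0, 2],
        threads=[128, 256, 512],
        split=[1, 2],
    )
    return [dict(zip(iter_params.keys(), combo))
            for combo in itertools.product(*iter_params.values())]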
...
@@ -62,30 +62,30 @@ def matmul(M,
split=1):
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
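The shape bookkeeping described above, as a short self-contained sketch (buffer-shape names other than A/B are assumptions consistent with the tiling):

def packed_shapes(M, N, K, block_N, block_K, split, num_bits=4):
    num_elems_per_byte = 8 // num_bits            # 2 FP4 values per uint8
    QK = K // num_elems_per_byte                  # packed K extent of B
    Block_QK = block_K // num_elems_per_byte      # packed K extent of one B tile
    assert K % (block_K * split) == 0, "K must be divisible by block_K * split"
    return dict(A_shape=(M, K), B_shape=(N, QK),
                B_shared_shape=(block_N, Block_QK),
                B_dequantize_shared_shape=(block_N, block_K))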
Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
This function validates the requested input/output datatypes and returns a TileLang `@T.macro` named `fast_dequant_bf16_fp4_twiddling` which:
- Loads compressed FP4 bytes from a shared buffer into per-thread local registers (vectorized loads).
- Invokes an external dequantization routine (via `T.call_extern`) to expand the packed FP4 values into BF16 in registers.
- Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
Notes and preconditions:
- Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
- The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
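A rough sketch of such a factory, assuming the outer-scope names listed above and the TileLang T.* primitives (vector widths and the thread-to-element mapping are illustrative, not the exact implementation):

def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
    assert in_dtype == "fp4" and out_dtype == "bfloat16"
    local_compress_size = block_N * Block_QK // threads   # packed bytes per thread (assumed)
    local_size = local_compress_size * num_elems_per_byte

    @T.macro
    def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
        T.import_source(import_source)                    # make the C plugin available
        tx = T.get_thread_binding()
        B_local = T.alloc_local((local_compress_size,), storage_dtype)
        B_dequant_local = T.alloc_local((local_size,), out_dtype)
        for v in T.vectorized(local_compress_size):       # packed bytes -> registers
            idx = tx * local_compress_size + v
            B_local[v] = B_shared[idx // Block_QK, idx % Block_QK]
        T.call_extern("handle", func_name,
                      T.address_of(B_local[0]), T.address_of(B_dequant_local[0]))
        for v in T.vectorized(local_size):                # registers -> shared BF16
            idx = tx * local_size + v
            B_dequantize_shared[idx // block_K, idx % block_K] = B_dequant_local[v]

    return fast_dequant_bf16_fp4_twiddling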
...
@@ -149,17 +149,17 @@ def matmul(M,
# import fast_dequantize plugin
"""
"""
Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer.
Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer.
This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters.
This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters.
B_dequantize_shared: Shared-memory buffer to receive dequantized BF16 values (written in-place by this routine).
B_dequantize_shared: Shared-memory buffer to receive dequantized BF16 values (written in-place by this routine).
Side effects:
Side effects:
- Imports the external dequantization plugin via `import_source` and invokes `func_name`.
- Imports the external dequantization plugin via `import_source` and invokes `func_name`.
- Writes dequantized BF16 results into `B_dequantize_shared`.
- Writes dequantized BF16 results into `B_dequantize_shared`.
Notes:
Notes:
- This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`).
- This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`).
- No value is returned; results are produced by mutation of `B_dequantize_shared`.
- No value is returned; results are produced by mutation of `B_dequantize_shared`.
Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.
The returned macro (named `simple_dequant_bf16_fp4`) expects B_shared and B_dequantize_shared buffers (shapes and a few loop/constant names like
`B_shared_shape`, `B_dequantize_shared_shape`, `storage_dtype`, `out_dtype`, `num_bits`, `num_elems_per_byte`, `block_N`, and `block_K`) to be available in the surrounding TIR scope. It:
- Unpacks 4-bit FP values from the packed uint8 representation in B_shared.
- Converts each 4-bit value to a bfloat16 element using an internal helper `_tir_u8_to_f4_to_bf16`.
- Writes the dequantized bfloat16 block into B_dequantize_shared.
Constraints:
- Supports only in_dtype="fp4" and out_dtype="bfloat16".
- The helper assumes nbit == 4 and produces bfloat16 values.
- The macro uses a fixed test-scale of 0 (no per-element scaling) as written.
Returns:
A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
"""
...
@@ -219,22 +219,22 @@ def matmul(M,
scale: tir.PrimExpr, dtype: str):
"""
"""
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the
This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters:
Parameters:
nbit (int): Number of bits in the packed element; must be 4.
nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16".
dtype (str): Target dtype string; must be "bfloat16".
Returns:
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
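For intuition, an equivalent decode in plain Python, assuming an e2m1 FP4 layout (sign, 2-bit exponent, 1-bit mantissa); the TIR helper assembles the bfloat16 bit pattern directly instead of going through Python floats:

def u8_to_fp4_to_float(val: int, pos: int, scale: int = 0) -> float:
    # Extract the nibble at position `pos` (0 or 1) from the byte `val`.
    f4 = (val >> (4 * pos)) & 0xF
    sign = -1.0 if (f4 >> 3) & 1 else 1.0
    exp = (f4 >> 1) & 0x3           # 2-bit exponent field
    man = f4 & 0x1                  # 1-bit mantissa field
    mag = man * 0.5 if exp == 0 else (1.0 + 0.5 * man) * 2.0 ** (exp - 1)
    return sign * mag * 2.0 ** scale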
Dequantize a packed FP4 uint8 shared buffer into BF16 and store the result into a shared dequantized buffer.
This helper:
- Loads B_shared into a local fragment, converts each packed FP4 element to BF16 using `_tir_u8_to_f4_to_bf16`, and writes the dequantized values into B_dequantize_shared.
- Iterates in parallel over the logical block columns (block_N) and block_K, unpacking elements from bytes using `num_elems_per_byte`.
- Uses a fixed scale of 0 in the conversion (placeholder for testing); `num_bits` and `num_elems_per_byte` are expected to be available from the enclosing scope.
Parameters:
B_shared: shared-memory buffer containing packed FP4 data (uint8-packed).
B_dequantize_shared: shared-memory buffer to receive BF16 dequantized values.
Side effects:
Writes dequantized BF16 values into B_dequantize_shared. No return value.
"""
...
@@ -298,7 +298,7 @@ def matmul(M,
):
"""
"""
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
...
@@ -307,16 +307,16 @@ def matmul(M,
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
Compute a reference BF16 matrix multiply using a simple (non-twiddled) dequantization of qB.
Converts the quantized tensor `qB` to full-precision values via `torch_convert`, computes C = A @ B^T in float32, and casts the result to bfloat16 before returning.
Parameters:
A (torch.Tensor): Left input matrix with shape (M, K).
qB (torch.Tensor): Quantized representation of the right matrix; expected to be compatible with `torch_convert` and represent a matrix whose transpose will be multiplied by A.
Returns:
torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
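A sketch of this reference, assuming `torch_convert` returns a dense float tensor of shape (N, K) (the function name below is illustrative):

import torch
from utils import torch_convert

def ref_program_simple(A: torch.Tensor, qB: torch.Tensor) -> torch.Tensor:
    B = torch_convert(qB)                                    # unpack FP4 -> float
    C = torch.matmul(A.to(torch.float32), B.to(torch.float32).T)
    return C.to(torch.bfloat16)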
Run and benchmark the tiled, optionally autotuned FP4->BF16 GEMM kernel and validate results against a PyTorch reference.
This function builds a matmul kernel (either with autotuning or fixed tiling), obtains a profiler, validates numerical correctness against the appropriate reference implementation (bit-twiddled fast dequantization or simple dequantization), and runs a benchmark that prints measured latency (ms) and effective TFLOPs.
Parameters:
m (int): Number of rows of A and output C (default 256).
n (int): Number of columns of B and output C (default 256).
k (int): Inner dimension (columns of A, rows of B) (default 256).
fast_dequant (bool): If True use the fast twiddling dequantization path and validate against the twiddling reference; otherwise use the simple dequant path (default True).
tune (bool): If True build the kernel with autotuning configurations; if False use a fixed tiling and threading configuration for reproducible benchmarking (default False).
Side effects:
- Prints latency and TFLOPs to stdout.
- Raises an assertion via the profiler if the kernel's outputs do not match the chosen reference within the tolerances (rtol=0.01, atol=0.01).
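A typical invocation, assuming the entry point is named `main` with the defaults described above:

if __name__ == "__main__":
    # Validate and benchmark the fast-dequant path with a fixed tiling.
    main(m=256, n=256, k=256, fast_dequant=True, tune=False)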
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
...
@@ -60,7 +60,7 @@ def get_configs():
- num_stages: pipeline stages (0, 2)
- threads: thread counts (128, 256, 512)
- split: K-splitting factor (1, 2)
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
...
@@ -99,7 +99,7 @@ def matmul(M,
split=1):
"""
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
...
@@ -107,7 +107,7 @@ def matmul(M,
...
@@ -107,7 +107,7 @@ def matmul(M,
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
in_dtype (str): element type of A (e.g., "fp4" in this file).
...
@@ -129,7 +129,7 @@ def matmul(M,
...
@@ -129,7 +129,7 @@ def matmul(M,
- dequantizes B via the chosen path into a shared dequantized tile,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
- writes the final MxN block to the global output tensor.
Notes:
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
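The per-block scale described above is a biased exponent: each Scale byte s contributes a multiplier of 2^(s - 127). A small NumPy illustration:

import numpy as np

def scale_factor(scale_byte: int) -> np.float32:
    # Bias of 127, so 127 -> 1.0, 130 -> 8.0, 124 -> 0.125.
    return np.float32(2.0) ** (int(scale_byte) - 127)

assert scale_factor(127) == 1.0
assert scale_factor(130) == 8.0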
...
@@ -194,21 +194,21 @@ def matmul(M,
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry, interprets it as an exponent bias
(applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor.
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results.
- Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element.
- k: current block index along the K dimension (used to select the appropriate slice of Scale).
Side effects:
- Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
"""
...
@@ -320,9 +320,9 @@ def matmul(M,
):
"""
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
...
@@ -376,14 +376,14 @@ def matmul(M,
def ref_program_twiddling(A, qB, Scale, Bias=None):
"""
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16.
Parameters:
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Returns:
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
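A plausible PyTorch reference matching this description (the 32-column grouping and the 2^(Scale - 127) factor are taken from the docstring; the exact helper behavior is assumed):

import torch
from utils import torch_convert_bit_twiddling

def ref_program_twiddling(A, qB, Scale, Bias=None):
    B = torch_convert_bit_twiddling(qB)                      # (N, K) float
    factor = torch.pow(2.0, Scale.to(torch.float32) - 127)   # (N, K // 32)
    B = B.to(torch.float32) * factor.repeat_interleave(32, dim=1)
    C = torch.matmul(A.to(torch.float32), B.T)
    return C.to(torch.bfloat16)

The simple (non-twiddling) reference described below presumably differs only in calling `torch_convert` instead of `torch_convert_bit_twiddling`.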
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.
Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS.
Parameters:
m (int): Number of rows of A / output rows. Default 256.
n (int): Number of columns of B / output columns. Default 256.
scale_size (int): Size of the per-block scale vector used for dequantization. Default 32.
fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True.
tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False.
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
...
@@ -60,7 +60,7 @@ def get_configs():
- num_stages: pipeline stages (0, 2)
- threads: thread counts (128, 256, 512)
- split: K-splitting factor (1, 2)
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
...
@@ -99,7 +99,7 @@ def matmul(M,
split=1):
"""
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
...
@@ -107,7 +107,7 @@ def matmul(M,
...
@@ -107,7 +107,7 @@ def matmul(M,
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
in_dtype (str): element type of A (e.g., "fp4" in this file).
...
@@ -129,7 +129,7 @@ def matmul(M,
...
@@ -129,7 +129,7 @@ def matmul(M,
- dequantizes B via the chosen path into a shared dequantized tile,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
- writes the final MxN block to the global output tensor.
Notes:
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
...
@@ -194,21 +194,21 @@ def matmul(M,
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry, interprets it as an exponent bias
(applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor.
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results.
- Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element.
- k: current block index along the K dimension (used to select the appropriate slice of Scale).
Side effects:
- Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
"""
...
@@ -319,9 +319,9 @@ def matmul(M,
):
"""
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
...
@@ -384,14 +384,14 @@ def matmul(M,
def ref_program_twiddling(A, qB, Scale, Bias=None):
"""
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
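A minimal PyTorch sketch of this reference, assuming `torch_convert_bit_twiddling` (defined elsewhere in this file) is in scope and that Scale stores biased exponents applied as 2^(Scale - 127):
import torch

def ref_program_twiddling_sketch(A, qB, Scale, Bias=None):
    # Decode the packed 4-bit weights to floating point.
    B = torch_convert_bit_twiddling(qB).float()
    # Per-element scale 2^(Scale[i][j // 32] - 127), with 32-column groups.
    cols = torch.arange(B.shape[1], device=B.device)
    B = B * torch.pow(2.0, Scale[:, cols // 32].float() - 127)
    # A @ B^T in float32, cast to bfloat16 at the end.
    return (A.float() @ B.T).to(torch.bfloat16)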
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
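A corresponding sketch of the simple path, assuming `torch_convert` is in scope and using the same 2^(Scale - 127) exponent convention; the group loop mirrors the Scale[i][j // 32] indexing described above:
import torch

def ref_program_simple_sketch(A, qB, Scale, Bias=None):
    B = torch_convert(qB).float()  # simple (non-twiddling) decode; .float() makes a local copy
    # Scale[i][g] supplies an exponent offset for columns j with g == j // 32.
    for g in range(Scale.shape[1]):
        cols = slice(g * 32, (g + 1) * 32)
        B[:, cols] *= torch.pow(2.0, Scale[:, g:g + 1].float() - 127)
    return (A.float() @ B.T).to(torch.bfloat16)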
Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.
Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS (see the formula sketch after this docstring).
Parameters:
m (int): Number of rows of A / output rows. Default 256.
n (int): Number of columns of B / output columns. Default 256.
scale_size (int): Size of the per-block scale vector used for dequantization. Default 32.
fast_dequant (bool): If True, validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True.
tune (bool): If True, obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False.
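For reference, a small sketch of the TFLOPS computation, assuming the usual GEMM operation count of 2*M*N*K and a latency measured in milliseconds:
def tflops_from_latency_ms(m: int, n: int, k: int, latency_ms: float) -> float:
    # 2*M*N*K multiply-accumulate operations; latency converted from ms to seconds.
    return (2.0 * m * n * k) / (latency_ms * 1e-3) / 1e12

# Example: a 256 x 256 x 1024 problem finishing in 0.05 ms
# tflops_from_latency_ms(256, 256, 1024, 0.05)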
Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding.
Each input byte holds two 4-bit encoded values (low and high nibble). For each nibble this function derives sign/scale bits, a 3-bit exponent fragment and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input.
Parameters:
tensor (torch.Tensor): 2D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values.
scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed. When supplied, per-output-element scaling is applied to the exponent using Scale.
Scale (torch.Tensor, optional): A 2D tensor used to supply per-element integer scale adjustments to the exponent. If scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size].
Returns:
torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
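A structural sketch of this decode, illustrative only: the nibble-to-value table below is a hypothetical stand-in (the real function assembles bf16 bit patterns from sign/exponent/mantissa fragments), the low/high-nibble interleaving is an assumption, and the optional Scale is applied as 2^(Scale - 127) to match the reference programs above.
import torch

# Hypothetical placeholder table; not the actual 4-bit encoding used by torch_convert.
NIBBLE_TO_FLOAT = torch.arange(16, dtype=torch.float32)

def torch_convert_sketch(tensor, scale_size=None, Scale=None):
    N, K = tensor.shape
    table = NIBBLE_TO_FLOAT.to(tensor.device)
    low = (tensor & 0x0F).long()           # first encoded 4-bit value in each byte
    high = ((tensor >> 4) & 0x0F).long()   # second encoded 4-bit value in each byte
    out = torch.empty(N, K * 2, dtype=torch.float32, device=tensor.device)
    out[:, 0::2] = table[low]              # element interleaving here is an assumption
    out[:, 1::2] = table[high]
    if scale_size is not None and Scale is not None:
        cols = torch.arange(K * 2, device=tensor.device)
        out = out * torch.pow(2.0, Scale[:, cols // scale_size].float() - 127)
    return out.to(torch.bfloat16)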
Compute the matrix product of A and the transpose of B.
A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes.
"""
return A @ B.T
...
@@ -22,26 +22,26 @@ def ref_program(A, B):
def get_configs(M, N, K, with_roller=False, topk=20):
"""
Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply.
When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended
configurations (device-specific TensorCore-friendly tilings). Each returned dict contains:
- block_M, block_N, block_K: tile sizes
- num_stages: pipeline staging (0 means no explicit staging)
- thread_num: total threads used for the block
- enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling)
When with_roller is False this returns the Cartesian product of a fixed set of candidate
parameters (see the sketch after this docstring); the returned dicts use the backward-compatible key name "enable_rasteration" for that flag.
Parameters:
M, N, K (int): GEMM dimensions used to generate valid tile sizes.
with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints;
otherwise use a predefined candidate grid.
topk (int): Maximum number of roller hints to request when with_roller is True.
Returns:
List[dict]: A list of configuration dictionaries as described above.
Raises:
ValueError: if with_roller is True but the roller returns no hints.
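A minimal sketch of the non-roller path; the candidate values below are placeholders rather than the exact grid used by the example, but the dictionary keys match the ones listed above:
import itertools

def get_configs_sketch(M, N, K, with_roller=False, topk=20):
    if with_roller:
        raise NotImplementedError("the roller path queries MatmulTemplate hints")
    # Hypothetical candidate grid; the real example defines its own values.
    keys = ["block_M", "block_N", "block_K", "num_stages", "thread_num", "enable_rasteration"]
    candidates = [
        [64, 128],      # block_M
        [64, 128],      # block_N
        [32, 64],       # block_K
        [0, 2],         # num_stages (0 means no explicit staging)
        [128, 256],     # thread_num
        [True, False],  # enable_rasteration (backward-compatible key spelling)
    ]
    return [dict(zip(keys, combo)) for combo in itertools.product(*candidates)]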