"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
),"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
assertkv_group==1,'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)'
assertkv_group==1,(
"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
)
BI=block_I
BI=block_I
NI=tilelang.cdiv(topk,block_I)
NI=tilelang.cdiv(topk,block_I)
assertNI%2==0,'NI should be a multiple of 2'
assertNI%2==0,"NI should be a multiple of 2"
D=dim
D=dim
D_tail=tail_dim
D_tail=tail_dim
KV_stride=kv_stride
KV_stride=kv_stride
ifhead_kv>64:
ifhead_kv>64:
asserthead_kv%64==0,'head_kv should be a multiple of 64'
asserthead_kv%64==0,"head_kv should be a multiple of 64"
assertq_start_index_s>kv_stride,"If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)"
assertq_start_index_s>kv_stride,(
"If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)"
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
num_elems_per_byte=8//num_bits
num_elems_per_byte=8//num_bits
storage_dtype="uint8"
storage_dtype="uint8"
...
@@ -189,8 +189,7 @@ def matmul(M,
...
@@ -189,8 +189,7 @@ def matmul(M,
# Finally, store the dequantized data to shared memory.
# Finally, store the dequantized data to shared memory.
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory.
- Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
Side effects:
- Writes the computed output block into the global tensor `C`.
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
- Uses and updates shared memory buffers and per-thread accumulators.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
The generated kernel accepts:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- A: dense matrix with element type `in_dtype`.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- An assertion enforces that K % (block_K * split) == 0.
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
The generated kernel accepts:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- A: dense matrix with element type `in_dtype`.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- An assertion enforces that K % (block_K * split) == 0.
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.