Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
...
...
@@ -131,13 +131,13 @@ def matmul(M,
- Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
Notes and preconditions:
- Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`.
- The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
- The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly.
- The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
...
...
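The per-thread transaction sizing above can be modeled in plain Python. This is a minimal sketch; the concrete bit widths below (BF16 output, FP4 input, 128-bit transactions) are assumptions chosen to match the surrounding text, not values read from the kernel:

```python
# Illustrative model of the per-thread sizing used by the fast-dequant macro.
MAX_TRANSACTION_SIZE_BITS = 128   # one vectorized 128-bit memory transaction
out_dtype_bits = 16               # bfloat16
num_bits = 4                      # fp4
num_elems_per_byte = 8 // num_bits

# BF16 elements each thread produces per 128-bit store
local_size = MAX_TRANSACTION_SIZE_BITS // out_dtype_bits
# packed uint8 bytes each thread loads to produce them
local_size_compressed = local_size // num_elems_per_byte

print(local_size, local_size_compressed)
```

With these assumed widths, each thread handles 8 BF16 outputs per transaction, sourced from 4 packed bytes; the macro's local/register buffers are sized from exactly these two quantities.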
@@ -189,12 +189,11 @@ def matmul(M,
# Finally, store the dequantized data to shared memory.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
...
...
@@ -111,7 +116,7 @@ def matmul(M,
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
...
...
@@ -136,7 +141,7 @@ def matmul(M,
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
...
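The packed-shape arithmetic above can be checked with a standalone sketch; the concrete dimensions here are assumptions for illustration only:

```python
# Standalone check of the packed-B shape arithmetic (illustrative sizes).
num_bits = 4
num_elems_per_byte = 8 // num_bits        # 2 fp4 values per uint8
M, K, block_K = 256, 1024, 64             # assumed example dimensions
QK = K // num_elems_per_byte              # packed K extent of B
Block_QK = block_K // num_elems_per_byte  # packed bytes per K-tile
A_shape = (M, K)
print(QK, Block_QK, A_shape)
```

So each row of B stores K fp4 values in QK = K/2 bytes, and one K-tile of block_K values occupies Block_QK bytes of shared storage.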
...
@@ -150,6 +155,7 @@ def matmul(M,
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
...
...
@@ -164,7 +170,7 @@ def matmul(M,
assert func_name is not None, "mxfp_intrin_info is not found"
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
...
...
@@ -175,12 +181,12 @@ def matmul(M,
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
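The nibble unpack and FP4 decode that `_tir_u8_to_f4_to_bf16` performs on-device can be modeled in plain Python. This sketch assumes the common e2m1 layout (1 sign, 2 exponent, 1 mantissa bit) for the 4-bit values; the actual TIR helper works on BF16 bit patterns rather than Python floats:

```python
# Pure-Python model of the per-nibble conversion done by _tir_u8_to_f4_to_bf16.
# The e2m1 bit layout below is an assumption for this sketch.

def unpack_fp4_pair(byte):
    """Split one packed uint8 into its low and high 4-bit values."""
    return byte & 0xF, (byte >> 4) & 0xF

def fp4_e2m1_to_float(nib):
    """Decode a 4-bit e2m1 value (sign, 2-bit exp, 1-bit mantissa) to float."""
    sign = -1.0 if (nib >> 3) & 1 else 1.0
    exp = (nib >> 1) & 0x3
    man = nib & 0x1
    if exp == 0:                       # subnormal: 0 or +-0.5
        return sign * man * 0.5
    return sign * (1.0 + 0.5 * man) * (2.0 ** (exp - 1))

lo, hi = unpack_fp4_pair(0x72)         # low nibble 0x2, high nibble 0x7
print(fp4_e2m1_to_float(lo), fp4_e2m1_to_float(hi))
```

In the kernel, the decoded value is then multiplied by the per-element exponential Scale before being written to B_dequantize_shared; that scaling step is omitted here.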
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
...
...
@@ -82,8 +83,8 @@ def matmul(M,
topk (int): number of experts selected per token.
E (int): number of experts.
padding_M (int): padded number of tokens after grouping and block alignment.
in_dtype (str): element type of A (e.g., T.bfloat16).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
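The padded_M parameter above rounds the grouped token count up to block alignment. A plain-Python sketch of that rounding; the helper name and example sizes are assumptions, not the kernel's API:

```python
# Illustrative block-alignment padding for the grouped (MoE) token dimension.
def pad_to_block(m, block_M):
    """Round m up to the next multiple of block_M."""
    return ((m + block_M - 1) // block_M) * block_M

print(pad_to_block(100, 64), pad_to_block(128, 64))
```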
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
...
...
@@ -145,12 +147,12 @@ def matmul(M,
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
"here we solve the H padding automatically; otherwise you should handle the Q copy and Output copy with your own mask (when kv_group == 1, using g_i * padded_H:(g_i+1) * padded_H is handled automatically)"
)
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
    assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
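`tilelang.cdiv` used above is a ceiling division; a plain-Python equivalent for reference:

```python
# Plain-Python equivalent of the ceiling division tilelang.cdiv(topk, block_I):
# the number of block_I-sized tiles needed to cover topk elements.
def cdiv(a, b):
    return (a + b - 1) // b

print(cdiv(8, 64), cdiv(65, 64))
```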