Unverified commit 8554cb01, authored by Tong WU and committed by GitHub

[Enhancement] Add a MXFP4 grouped GEMM example for FusedMoE (#811)



* [Enhancement] Enhance dequantization examples and utilities

- Added a new example for grouped matrix multiplication with experts in `example_dequant_groupgemm_bf16_mxfp4_hopper.py`.
- Improved dequantization logic in existing examples by replacing nested loops with vectorized operations for better performance.
- Updated `torch_convert_bit_twiddling` function in `utils.py` to utilize parallel processing, enhancing efficiency and clarity in the conversion process.
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>

* fix typos in docstrings

* remove redundant code

* [Format] Unreproducible debug with T.print

* [BugFix] Correct dtype in ref dequantize; larger data distribution

* [Format]

* [Refactor] Clean up and optimize example_dequant_groupgemm_bf16_mxfp4_hopper.py and utils.py

- Removed unnecessary cache disabling and manual seed setting in the example.
- Simplified nested loops into parallelized operations for better readability and performance.
- Updated the assertion function in utils.py to print detailed error messages.
- Adjusted tensor sizes in examples

* [Refactor] Update import path in example_dequant_gemm_fine_grained.py

- Changed the import statement for `_tir_packed_to_unsigned_convert` from `bitblas.quantization` to `tilelang.quantize` to reflect the new module structure.

* lint

* rename and add test

* lint

* [Feature] Enhance autotuning and configuration generation in example_dequant_groupedgemm_bf16_mxfp4_hopper.py

- Added a new function `get_configs()` to generate hyperparameter configurations for tuning.
- Updated the `matmul` function to utilize autotuning with the new configurations.
- Improved kernel performance via vectorization and thread-block swizzling.
- Enhanced the main function to support the new autotuning inputs and updated parameters for better performance.

* lint

* fix typo

* fix typo and lint

* make ci format check happy

* fix ci

---------
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
parent e4a346fe
@@ -389,9 +389,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -414,9 +412,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -440,9 +436,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -470,9 +464,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
     """
    dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
......
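The four hunks above replace per-element Python loops with a single broadcasted multiply. A minimal standalone check of that rewrite (hypothetical shapes; 32 columns share one scale entry, as in the reference programs):

```python
import torch

N, K, scale_size = 4, 64, 32
B = torch.randn(N, K, dtype=torch.float32)
Scale = torch.randint(0, 8, (N, K // scale_size))

# Old path: per-element nested loops.
B_loop = B.clone()
for i in range(N):
    for j in range(K):
        B_loop[i][j] = B_loop[i][j] * (2**(Scale[i][j // scale_size]))

# New path: one broadcasted multiply, indexing Scale by column group.
B_vec = B * 2**(Scale[:, torch.arange(K, device=B.device) // scale_size])

assert torch.allclose(B_loop, B_vec)
```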
@@ -23,7 +23,7 @@ def matmul(
         threads,
         num_bits=4,
 ):
-    from bitblas.quantization import _tir_packed_to_unsigned_convert
+    from tilelang.quantize import _tir_packed_to_unsigned_convert
     num_elems_per_byte = 8 // num_bits
     storage_dtype = "int8"
     storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
......
import tilelang
import tilelang.language as T
from tilelang.quantize import _tir_u8_to_f4_to_bf16
from tilelang import tvm as tvm
from tvm import DataType
import torch
from utils import torch_convert_bit_twiddling, assert_similar
from tilelang.autotuner import set_autotune_inputs
def get_configs():
"""
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
- block_M, block_N, block_K: tiling sizes
- num_stages: pipeline stages
- threads: thread counts
- split: K-splitting factor
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[128],
block_N=[64, 128, 256],
block_K=[128],
num_stages=[0, 1, 2],
threads=[128, 256, 512],
split=[1],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
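# Illustrative only: the first entry produced by the Cartesian product above is
# {'block_M': 128, 'block_N': 64, 'block_K': 128, 'num_stages': 0, 'threads': 128, 'split': 1}.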
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
topk,
E,
padding_M,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=128,
block_N=256,
block_K=128,
num_stages=2,
threads=256,
split=1):
"""
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype` and shape (M, K).
- B: packed quantized matrix for all experts, stored as uint8 with `num_bits` bits per element, shape (E, N, QK), where QK = K / (8/num_bits).
- Scale: per-expert, per-block scale/exponent information for dequantizing B, shape (E, N, K // scale_size).
- Bias: per-expert, per-output bias, shape (E, N).
- topk_weights: router weights for the top-k experts for each token, flattened to shape (M * topk,).
- sorted_token_ids: flattened and padded tensor of token indices, shape (padding_M,).
- expert_ids: expert id for each block_M-sized block of tokens in the padded batch, shape (padding_M // block_M,).
- C: output tensor, shape (M, topk, N).
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is (M, topk, N)). K must be divisible by (block_K * split).
topk (int): number of experts selected per token.
E (int): number of experts.
padding_M (int): padded number of tokens after grouping and block alignment.
in_dtype (str): element type of A (e.g., "bfloat16").
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 128, 256, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the grouped, pipelined GEMM that:
- loads tiled blocks of A and packed B for each expert to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- applies per-token topk weights and bias,
- writes the final (M, topk, N) block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
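# For example, with num_bits=4 (MXFP4): num_elems_per_byte = 2, so each row of B holds
# QK = K // 2 packed uint8 values and each shared tile row holds Block_QK = block_K // 2 bytes.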
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_N)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info is missing c_source"
assert func_name is not None, "mxfp_intrin_info is missing func_name"
import_source = import_source
# the dequant part is the same as in dequant_gemm
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
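# With out_dtype="bfloat16" (16 bits): local_size = 128 // 16 = 8 dequantized values per
# thread per iteration, and local_compress_size = 8 // 2 = 4 packed uint8 bytes.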
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k):
# import fast_dequantize plugin
"""
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry, interprets it as an exponent bias
(applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor.
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Parameters:
- B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout).
- B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values.
- Scale_shared: per-block scale tensor; entries are interpreted such that the multiplicative scale
= 2^(Scale - 127).
- k: block index along the K dimension used to select the appropriate Scale entries.
Side effects:
- Mutates B_dequantize_shared in shared memory.
- Calls an external intrinsic function (must be provided by the environment via `import_source`
and `func_name`) to perform the low-level unpacking/dequantization.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from share memory to register.
# Prepare for dequant.
index_base = i * threads * local_compress_size + tx * local_compress_size
for v in T.vectorized(0, local_compress_size):
index = index_base + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
index_scale = index_base // (scale_size // num_elems_per_byte)
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.Parallel(local_size):
B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0]
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), "int32"),
expert_ids: T.Tensor((padding_M // block_M), "int32"),
C: T.Tensor((M, topk, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
topk_weights_shared = T.alloc_shared((block_M), out_dtype)
sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
expert_id = T.alloc_local((1), "int32") # the expert id for the current block
# To use 1D TMA, the last dim of Scale_shared must have stride=1
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.use_swizzle(10)
if threads == 512:
T.disable_warp_group_reg_alloc()
T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared)
expert_id[0] = expert_ids[by]
# Get the topk weights of each token in the current block
for i in T.Parallel(block_M):
if sorted_token_ids_shared[i] != -1:
topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]]
# Get bias and scale based on the expert id
if with_bias:
T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared)
else:
T.clear(Bias_shared)
T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = Bias_shared[j]
tx = T.get_thread_binding()
for k in T.Pipelined(K // block_K, num_stages=num_stages):
# Each thread copies 4 bytes, local size is 16
for copy_i in T.serial(block_M * block_K // threads // 16):
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_K] != -1:
for copy_j in T.vectorized(16):
A_shared[base // block_K, base % block_K +
copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
k * block_K + base % block_K + copy_j]
T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = C_local[i, j] * topk_weights_shared[i]
T.copy(C_local, C_shared)
for copy_i in T.serial(block_M * block_N // threads // 16):
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_N] != -1:
for copy_j in T.vectorized(16):
C[sorted_token_ids_shared[base // block_N] // topk,
sorted_token_ids_shared[base // block_N] % topk, bx * block_N +
base % block_N + copy_j] = C_shared[base // block_N,
base % block_N + copy_j]
return main
def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256):
dtypeC = "bfloat16"
M, K = A.shape
E, N, QK = qB.shape
topk = topk_weights.shape[0] // M
scale_size = K // Scale.shape[2]
assert scale_size == 32 # MXFP4
# Initialize output tensor
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda')
# Iterate over sorted_token_ids
for idx in range(len(sorted_token_ids)): # padding_M
token_id = sorted_token_ids[idx]
if token_id == -1:
continue
expert_id = expert_ids[idx // block_M]
topk_idx = token_id % topk
# Get the token embedding
token_embedding = A[token_id // topk]
# Dequantize the expert weights
B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K)
B *= 2**(
Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(
torch.bfloat16))
# Compute the output for this token-expert pair
# token_embedding @ B.T + bias
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(
torch.bfloat16)) + Bias[expert_id]
output = output.to(torch.__getattribute__(dtypeC))
# Apply the topk weight
weight = topk_weights[token_id]
output = output * weight
# Store the result
C[token_id // topk, topk_idx] = output
return C
def get_data(m, n, k, qk, scale_size, topk, E, block_M):
A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
qB = torch.randint(
0, 256, (E, n, qk), dtype=torch.uint8,
device='cuda') # Quantized weight tensor for E experts.
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda')
Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
# topk_weights: Router weights for the top-k experts for each token.
# Shape: (m, topk)
# tokens_experts: A flattened tensor of expert assignments for each token.
# For each of m tokens, topk unique experts are chosen. Shape: (m * topk,)
topk_weights, tokens_experts = torch.topk(weights, topk, dim=-1)
tokens_experts = tokens_experts.reshape(m * topk)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.reshape(m * topk)
sorted_expert_vals, sorted_indices = torch.sort(tokens_experts, stable=True)
sorted_token_ids = sorted_indices
unique_expert_ids, counts = torch.unique_consecutive(sorted_expert_vals, return_counts=True)
expert_ids = []
padded_token_ids = []
start = 0
for eid, cnt in zip(unique_expert_ids.tolist(), counts.tolist()):
end = start + cnt
group_token_ids = sorted_token_ids[start:end]
pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
if pad_len > 0:
# -1 for padding (`M` instead in vLLM moe_align_block_size())
group_token_ids = torch.cat([
group_token_ids,
torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda')
])
padded_token_ids.append(group_token_ids)
expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M))
start = end
# sorted_token_ids: The final flattened and padded tensor of token indices.
sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,)
# expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`.
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M // block_M,)
padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding
print(f'{sorted_token_ids=}')
print(f'{expert_ids=}')
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
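# Illustrative (hypothetical) routing layout for m=6, topk=1, E=2, block_M=4:
#   sorted_token_ids = [0, 2, 5, -1, 1, 3, 4, -1]  # expert-0 tokens padded to block_M, then expert-1 tokens
#   expert_ids       = [0, 1]                      # one expert id per block_M-sized block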
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, topk=4, E=32):
# Tunable parameters
block_M, block_N, block_K = 128, 256, 128 # noqa: F841
num_stages = 1 # noqa: F841
threads = 512 # noqa: F841
split = 1 # noqa: F841
total_flops = 2 * m * n * k * topk
num_bits = 4
num_elems_per_byte = 8 // num_bits
qk = k // num_elems_per_byte
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(
m, n, k, qk, scale_size, topk, E, block_M)
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
# Autotune with inputs manually composed
kernel = matmul(
m,
n,
k,
topk,
E,
padding_M,
"bfloat16",
"bfloat16",
"float32",
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias,
)
print(f'Best config: {kernel.config}')
output = kernel(
A,
qB,
Scale,
Bias,
topk_weights,
sorted_token_ids,
expert_ids,
)
print('Tilelang kernel run finished.')
ref_output = ref_moe(
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids,
block_M=block_M) # Maybe a little bit slow...
latency = tilelang.profiler.do_bench(
lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
diff = (output - ref_output).abs()
max_val = diff.max()
max_idx = diff.argmax()
print(f"max abs diff: {max_val} at index: {max_idx}")
assert_similar(
output, ref_output, name="output",
eps=1e-5) # We care about the similarity rather than abs. difference
print("All checks pass. ✅")
if __name__ == "__main__":
M, N, K = 16384, 5760, 2944 # From gpt-oss-20b MoE's first gemm
scale_size = 32
topk = 4 # experts activated for each token
E = 32 # number of experts
main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
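For orientation, a shape-only sketch of the tensors implied by the gpt-oss-20b sizes used above (derived from the definitions in this example, not executed as part of it):

```python
M, N, K, E, topk, scale_size = 16384, 5760, 2944, 32, 4, 32
QK = K // 2                               # two 4-bit values per uint8
print("B:", (E, N, QK))                   # (32, 5760, 1472)
print("Scale:", (E, N, K // scale_size))  # (32, 5760, 92)
print("C:", (M, topk, N))                 # (16384, 4, 5760)
```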
@@ -4,6 +4,7 @@ import example_dequant_gemv_fp16xint4
 import example_dequant_gemm_fp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper_tma
+import example_dequant_groupedgemm_bf16_mxfp4_hopper
 import example_dequant_gemm_w4a8
@@ -31,6 +32,13 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
 @tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
+    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_dequant_gemm_w4a8():
     example_dequant_gemm_w4a8.main()
......
@@ -3,8 +3,6 @@ import torch


 def torch_convert_bit_twiddling(tensor):
     """
-    Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
-
     This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
     Parameters:
@@ -16,38 +14,46 @@ def torch_convert_bit_twiddling(tensor):
     Raises:
         AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
     """
+    assert tensor.dim() == 2 and tensor.dtype == torch.uint8
+    N, K = tensor.shape
+    assert K % 2 == 0, "Number of columns must be even"

-    def _convert(val0, val1, pos) -> torch.bfloat16:
-        assert val0.dtype == torch.uint8
-        assert val1.dtype == torch.uint8
-        val0 = val0.view(torch.uint8)
-        val1 = val1.view(torch.uint8)
-        val_concat = (val0.item() << 8) | val1.item()
-        mask = 0b1000000111000000
-        if pos == 0:
-            bf16 = val_concat & mask
-        elif pos == 1:
-            bf16 = (val_concat << 3) & mask
-        elif pos == 2:
-            bf16 = (val_concat << 6) & mask
-        elif pos == 3:
-            mask1 = 0b1000000000000000
-            mask2 = 0b0000000110000000
-            mask3 = 0b0000000001000000
-            bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | (
-                (val_concat >> 7) & mask3)
-        bf16_new = torch.tensor([bf16], dtype=torch.uint16, device=val0.device).view(torch.bfloat16)
-        # Add bias for change from fp4 to bf16
-        bf16_new = bf16_new.item() * (2**126)
-        return bf16_new
+    # Combine pairs of uint8 values into uint32 for safe bitwise ops on CUDA
+    val0 = tensor[:, 0::2].to(torch.int32)
+    val1 = tensor[:, 1::2].to(torch.int32)
+    val_concat = (val0 << 8) | val1  # (N, K//2), uint32

-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 4 * 2], tensor[i][j // 4 * 2 + 1], j % 4)
-    return new_tensor
+    # Expand to match output shape where each pair generates 4 values
+    val_concat_expanded = val_concat.repeat_interleave(4, dim=1)  # (N, K//2*4)
+
+    # Positional encoding for bit-twiddling logic
+    pos = torch.arange(K * 2, device=tensor.device) % 4  # (K*2,)
+
+    # Bit masks for decoding (as uint32 for CUDA compatibility)
+    mask = 0b1000000111000000
+    mask1 = 0b1000000000000000
+    mask2 = 0b0000000110000000
+    mask3 = 0b0000000001000000
+
+    # Calculate results for all 4 positions in parallel
+    res0 = val_concat_expanded & mask
+    res1 = (val_concat_expanded << 3) & mask
+    res2 = (val_concat_expanded << 6) & mask
+    res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
+        (val_concat_expanded >> 7) & mask3)
+
+    # Select the correct result based on position
+    bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
+                                                   torch.where(pos == 2, res2, res3)))
+
+    # Convert to uint16 for .view(torch.bfloat16)
+    bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
+    bf16_bf16 = bf16_uint16.view(torch.bfloat16)
+
+    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
+    bf16_new = bf16_bf16 * (2.0**126)
+    return bf16_new


 def torch_convert(tensor, scale_size=None, Scale=None):
@@ -106,3 +112,41 @@ def print_bit(name, val):
     val_cpu = val.cpu().item()
     binary_repr = f'{val_cpu:032b}'
     print(name, binary_repr)
+
+
+def print_red_warning(message):
+    print(f"\033[31mWARNING: {message}\033[0m")
+
+
+def calc_sim(x, y, name="tensor"):
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print_red_warning(f'{name} all zero')
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
+    x_mask = torch.isfinite(x)
+    y_mask = torch.isfinite(y)
+    if not torch.all(x_mask == y_mask):
+        print_red_warning(f'{name} Error: isfinite mask mismatch')
+        if raise_assert:
+            raise AssertionError
+    if not torch.isclose(
+            x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
+            equal_nan=True).all():
+        print_red_warning(f'{name} Error: nonfinite value mismatch')
+        if raise_assert:
+            raise AssertionError
+    x = x.masked_fill(~x_mask, 0)
+    y = y.masked_fill(~y_mask, 0)
+    sim = calc_sim(x, y, name)
+    diff = (1. - sim).item()
+    print(f'{diff=}')
+    if not (0 <= diff <= eps):
+        print_red_warning(f'{name} Error: {diff=}')
+        if raise_assert:
+            raise AssertionError
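A quick usage sketch for the new similarity helpers, assuming the utils.py shown above is on the import path:

```python
import torch
from utils import assert_similar, calc_sim

x = torch.randn(1024)
y = x + 1e-6 * torch.randn(1024)

print(float(calc_sim(x, y)))                 # ~1.0 for nearly identical tensors
assert_similar(x, y, eps=1e-5, name="demo")  # prints diff= and raises AssertionError only on mismatch
```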
@@ -331,13 +331,13 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,


 def sync_threads():
-    """Synchronize all threads in a warp.
+    """Synchronize all threads in a block.
     """
     return tir.op.tvm_storage_sync("shared")


 def sync_global():
-    """Synchronize all threads in a block.
+    """Synchronize all threads in the entire grid.
     """
     tx, ty, tz = get_thread_bindings()
     ex, ey, ez = get_block_extents()
......
@@ -5,6 +5,7 @@ from .quantization import (
     _tir_packed_to_fp4_to_f16,  # noqa: F401
     _tir_u8_to_f8_e4m3_to_f16,  # noqa: F401
     _tir_packed_to_unsigned_convert_with_zeros,  # noqa: F401
+    _tir_u8_to_f4_to_bf16,  # noqa: F401
 )
 from .utils import (
......