Unverified commit 8554cb01, authored by Tong WU and committed by GitHub

[Enhancement] Add an MXFP4 grouped GEMM example for FusedMoE (#811)



* [Enhancement] Enhance dequantization examples and utilities

- Added a new example of grouped matrix multiplication over experts (for FusedMoE) in `example_dequant_groupgemm_bf16_mxfp4_hopper.py`.
- Improved dequantization logic in the existing examples by replacing nested Python loops with vectorized operations for better performance (a minimal sketch follows below).
- Updated the `torch_convert_bit_twiddling` function in `utils.py` to use vectorized, parallel processing, improving both the efficiency and the clarity of the conversion.
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
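As a rough, self-contained sketch (the shapes and the 32-column scale group below are illustrative, not copied from the examples), the vectorized scale application replaces the per-element loop like this:

```python
import torch

# Illustrative shapes: 64 rows, 128 columns, one scale exponent per 32 columns.
B = torch.randn(64, 128, dtype=torch.bfloat16)
Scale = torch.randint(-2, 3, (64, 128 // 32))

# Loop form (per element): B[i][j] *= 2 ** Scale[i][j // 32]
# Vectorized form: gather each column's group exponent once and broadcast.
col_groups = torch.arange(B.shape[1], device=B.device) // 32   # (128,)
B = B * (2.0 ** Scale[:, col_groups]).to(B.dtype)              # (64, 128)
```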

* fix typos in docstrings

* remove redundant code

* [Format] Unreproducible debug with T.print

* [BugFix] Correct dtype in ref dequantize; larger data distribution

* [Format]

* [Refactor] Clean up and optimize example_dequant_groupgemm_bf16_mxfp4_hopper.py and utils.py

- Removed unnecessary cache disabling and manual seed setting in the example.
- Simplified nested loops into parallelized operations for better readability and performance.
- Updated the assertion function in utils.py to print detailed error messages.
- Adjusted tensor sizes in the examples.

* [Refactor] Update import path in example_dequant_gemm_fine_grained.py

- Changed the import statement for `_tir_packed_to_unsigned_convert` from `bitblas.quantization` to `tilelang.quantize` to reflect the new module structure.

* lint

* rename and add test

* lint

* [Feature] Enhance autotuning and configuration generation in example_dequant_groupedgemm_bf16_mxfp4_hopper.py

- Added a new function `get_configs()` to generate hyperparameter configurations for tuning (see the sketch after this list).
- Updated the `matmul` function to utilize autotuning with the new configurations.
- Improved kernel performance via vectorization and threadblock swizzling.
- Enhanced the main function to support the new autotuning inputs and updated parameters for better performance.
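As a rough illustration only (the parameter names and value ranges below are hypothetical, not the ones used in the example), such a config generator typically enumerates the cross product of a few tile and scheduling knobs; the autotuner then compiles and times the kernel once per configuration and keeps the fastest:

```python
import itertools

def get_configs():
    """Hypothetical sketch of a tuning-space generator."""
    block_M = [64, 128]
    block_N = [64, 128, 256]
    num_stages = [2, 3]
    threads = [128, 256]
    return [
        dict(block_M=m, block_N=n, num_stages=s, threads=t)
        for m, n, s, t in itertools.product(block_M, block_N, num_stages, threads)
    ]
```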

* lint

* fix typo

* fix typo and lint

* make ci format check happy

* fix ci

---------
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
parent e4a346fe
@@ -389,9 +389,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
@@ -414,9 +412,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
@@ -440,9 +436,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
@@ -470,9 +464,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
@@ -23,7 +23,7 @@ def matmul(
threads,
num_bits=4,
):
from bitblas.quantization import _tir_packed_to_unsigned_convert
from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
@@ -4,6 +4,7 @@ import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper_tma
import example_dequant_groupedgemm_bf16_mxfp4_hopper
import example_dequant_gemm_w4a8
@@ -31,6 +32,13 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_w4a8():
example_dequant_gemm_w4a8.main()
@@ -3,8 +3,6 @@ import torch
def torch_convert_bit_twiddling(tensor):
"""
Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
Parameters:
@@ -16,38 +14,46 @@ def torch_convert_bit_twiddling(tensor):
Raises:
AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
"""
assert tensor.dim() == 2 and tensor.dtype == torch.uint8
N, K = tensor.shape
assert K % 2 == 0, "Number of columns must be even"
def _convert(val0, val1, pos) -> torch.bfloat16:
assert val0.dtype == torch.uint8
assert val1.dtype == torch.uint8
val0 = val0.view(torch.uint8)
val1 = val1.view(torch.uint8)
val_concat = (val0.item() << 8) | val1.item()
mask = 0b1000000111000000
if pos == 0:
bf16 = val_concat & mask
elif pos == 1:
bf16 = (val_concat << 3) & mask
elif pos == 2:
bf16 = (val_concat << 6) & mask
elif pos == 3:
mask1 = 0b1000000000000000
mask2 = 0b0000000110000000
mask3 = 0b0000000001000000
bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | (
(val_concat >> 7) & mask3)
bf16_new = torch.tensor([bf16], dtype=torch.uint16, device=val0.device).view(torch.bfloat16)
# Add bias for change from fp4 to bf16
bf16_new = bf16_new.item() * (2**126)
return bf16_new
# Combine pairs of uint8 values into uint32 for safe bitwise ops on CUDA
val0 = tensor[:, 0::2].to(torch.int32)
val1 = tensor[:, 1::2].to(torch.int32)
val_concat = (val0 << 8) | val1 # (N, K//2), uint32
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 4 * 2], tensor[i][j // 4 * 2 + 1], j % 4)
return new_tensor
# Expand to match output shape where each pair generates 4 values
val_concat_expanded = val_concat.repeat_interleave(4, dim=1) # (N, K//2*4)
# Positional encoding for bit-twiddling logic
pos = torch.arange(K * 2, device=tensor.device) % 4 # (K*2,)
# Bit masks for decoding (as uint32 for CUDA compatibility)
mask = 0b1000000111000000
mask1 = 0b1000000000000000
mask2 = 0b0000000110000000
mask3 = 0b0000000001000000
# Calculate results for all 4 positions in parallel
res0 = val_concat_expanded & mask
res1 = (val_concat_expanded << 3) & mask
res2 = (val_concat_expanded << 6) & mask
res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
(val_concat_expanded >> 7) & mask3)
# Select the correct result based on position
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
torch.where(pos == 2, res2, res3)))
# Convert to uint16 for .view(torch.bfloat16)
bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
bf16_bf16 = bf16_uint16.view(torch.bfloat16)
# Avoid integer overflow by using a float32 multiplier for the exponent scaling
bf16_new = bf16_bf16 * (2.0**126)
return bf16_new
def torch_convert(tensor, scale_size=None, Scale=None):
@@ -106,3 +112,41 @@ def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
if raise_assert:
raise AssertionError
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = (1. - sim).item()
print(f'{diff=}')
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff=}')
if raise_assert:
raise AssertionError
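# Hypothetical usage sketch for the helpers above (tensors and tolerance are illustrative):
ref = torch.randn(4, 4, dtype=torch.bfloat16)
out = ref + 1e-4 * torch.randn_like(ref)  # stand-in for a kernel result with small numerical error
assert_similar(out, ref, eps=1e-3, name="demo")  # passes when diff = 1 - calc_sim(out, ref) <= eps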
@@ -331,13 +331,13 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
def sync_threads():
"""Synchronize all threads in a warp.
"""Synchronize all threads in a block.
"""
return tir.op.tvm_storage_sync("shared")
def sync_global():
"""Synchronize all threads in a block.
"""Synchronize all threads in the entire grid.
"""
tx, ty, tz = get_thread_bindings()
ex, ey, ez = get_block_extents()
@@ -5,6 +5,7 @@ from .quantization import (
_tir_packed_to_fp4_to_f16, # noqa: F401
_tir_u8_to_f8_e4m3_to_f16, # noqa: F401
_tir_packed_to_unsigned_convert_with_zeros, # noqa: F401
_tir_u8_to_f4_to_bf16, # noqa: F401
)
from .utils import (