Unverified Commit fd199a4a authored by Zhengju Tang, committed by GitHub

[MXFP4] Add bias to MXFP4 GEMM kernel (#753)

* [MXFP4] Add bias to gemm kernel

* [Lint]

* [Lint] Rename "bias" to "Bias"
parent cf7be057
@@ -90,6 +90,7 @@ def matmul(M,
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
@@ -120,7 +121,8 @@ def matmul(M,
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
@@ -139,9 +141,11 @@ def matmul(M,
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
Bias_shape = (M, N)
Scale_shape = (N, K // scale_size)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_M, block_N)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
@@ -311,6 +315,7 @@ def matmul(M,
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
@@ -328,7 +333,7 @@ def matmul(M,
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
@@ -337,10 +342,22 @@ def matmul(M,
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
if with_bias:
T.annotate_layout({
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
})
if threads == 512:
T.disable_warp_group_reg_alloc()
T.clear(C_local)
if with_bias:
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
Bias_shared)
T.copy(Bias_shared, C_local)
else:
T.clear(C_local)
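# Seeding C_local with Bias (instead of clearing it) folds the bias add into the
# K-loop accumulation, so the epilogue needs no separate bias addition.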
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
@@ -356,7 +373,7 @@ def matmul(M,
return main
def ref_program_twiddling(A, qB, Scale):
def ref_program_twiddling(A, qB, Scale, Bias=None):
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
@@ -380,7 +397,32 @@ def ref_program_twiddling(A, qB, Scale):
return C
def ref_program_simple(A, qB, Scale):
def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
"""
Compute A @ B^T + Bias where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^Scale[i][j // 32] (each Scale entry is shared by a group of 32 columns of B), computes A · B^T in float, adds Bias, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Bias (torch.Tensor): Bias tensor with shape (M, N).
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
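A vectorized sketch of the same reference math (assuming Scale has shape (N, K // scale_size) with scale_size = 32, matching Scale_shape in the kernel); it reuses this file's torch_convert_bit_twiddling helper, and the identical pattern applies to ref_program_simple_with_bias with torch_convert instead:
def ref_twiddling_with_bias_vectorized(A, qB, Scale, Bias):
    # Dequantize the packed 4-bit B, then broadcast each group exponent across its 32 columns.
    B = torch_convert_bit_twiddling(qB).to(torch.float)
    scale = torch.pow(2.0, Scale.to(torch.float)).repeat_interleave(32, dim=1)  # (N, K)
    # Same math as the element-wise loop above: C = A @ B^T + Bias, cast to bfloat16.
    return (torch.matmul(A.to(torch.float), (B * scale).T) + Bias).to(torch.bfloat16)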
def ref_program_simple(A, qB, Scale, Bias=None):
"""
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
@@ -406,7 +448,37 @@ def ref_program_simple(A, qB, Scale):
return C
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
def ref_program_simple_with_bias(A, qB, Scale, Bias):
"""
Compute a BF16 matrix product A · B^T + Bias from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor of 2^Scale[i][j // 32] (Scale supplies per-group exponents in 32-column groups), then computes C = A · B^T + Bias and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
- Bias: 2D tensor of shape (M, N), added to the matmul result before the bfloat16 cast.
Returns:
- 2D bfloat16 tensor C containing A · B^T + Bias.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False):
"""
Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.
@@ -435,7 +507,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant)
fast_dequant=fast_dequant,
with_bias=with_bias)
else:
kernel = matmul(
m,
@@ -452,14 +525,21 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
num_stages=2,
threads=256,
split=1,
fast_dequant=fast_dequant)
fast_dequant=fast_dequant,
with_bias=with_bias)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
if with_bias:
profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
if with_bias:
profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
@@ -469,5 +549,7 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
if __name__ == "__main__":
M, N, K = 256, 256, 256
scale_size = 32
main(M, N, K, scale_size, fast_dequant=True)
main(M, N, K, scale_size, fast_dequant=False)
main(M, N, K, scale_size, fast_dequant=True, with_bias=True)
main(M, N, K, scale_size, fast_dequant=False, with_bias=True)
main(M, N, K, scale_size, fast_dequant=True, with_bias=False)
main(M, N, K, scale_size, fast_dequant=False, with_bias=False)
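A minimal follow-up sketch (shapes are illustrative, chosen so K % (block_K * split) == 0 still holds) that sweeps the same quick check over both dequant paths and both bias settings via the main() entry point above:
for (M, N, K) in [(256, 256, 256), (1024, 1024, 1024)]:
    for fast_dequant in (True, False):
        for with_bias in (True, False):
            main(M, N, K, scale_size=32, fast_dequant=fast_dequant, with_bias=with_bias)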