Unverified Commit fd199a4a authored by Zhengju Tang, committed by GitHub

[MXFP4] Add bias to MXFP4 GEMM kernel (#753)

* [MXFP4] Add bias to gemm kernel

* [Lint]

* [Lint] Rename "bias" to "Bias"
parent cf7be057
@@ -90,6 +90,7 @@ def matmul(M,
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
@@ -120,7 +121,8 @@ def matmul(M,
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
@@ -139,9 +141,11 @@ def matmul(M,
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
Bias_shape = (M, N)
Scale_shape = (N, K // scale_size)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_M, block_N)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
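
A quick way to sanity-check these shapes is to replay the packing arithmetic. The sketch below is illustrative only: `QK` and `storage_nbit` are not visible in the hunks shown here, so their definitions (uint8 storage holding two 4-bit values per byte) are assumptions.

```python
# Sketch of the shape arithmetic implied by the hunk above. QK is not shown in
# this diff; QK = K // num_elems_per_byte is an assumption based on B being
# stored as packed 4-bit values (two per uint8 byte).
num_bits = 4
storage_nbit = 8                                # assumed uint8 storage
num_elems_per_byte = storage_nbit // num_bits   # 2 elements per byte

M = N = K = 256
block_M, block_N, block_K = 256, 128, 128
scale_size = 32

QK = K // num_elems_per_byte              # packed K extent of B (assumed)
Block_QK = block_K // num_elems_per_byte  # packed K extent of a B tile
B_shape = (N, QK)                         # (256, 128)
Scale_shape = (N, K // scale_size)        # one scale per 32 weights along K
Bias_shape = (M, N)                       # full output shape, added once per GEMM
print(B_shape, Scale_shape, Bias_shape, Block_QK)
```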
@@ -311,6 +315,7 @@ def matmul(M,
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
@@ -328,7 +333,7 @@ def matmul(M,
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
@@ -337,10 +342,22 @@ def matmul(M,
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
if with_bias:
    T.annotate_layout({
        Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
    })
if threads == 512:
    T.disable_warp_group_reg_alloc()
if with_bias:
    T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
           Bias_shared)
    T.copy(Bias_shared, C_local)
else:
    T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
    T.copy(A[by * block_M, k * block_K], A_shared)
    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
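
The edit above swaps the unconditional `T.clear(C_local)` for either a bias preload or a clear, so the pipelined K loop accumulates partial products on top of the bias. A minimal PyTorch sketch of why preloading the accumulator is equivalent to adding the bias after the GEMM (shapes and block size here are arbitrary):

```python
import torch

# Initializing the accumulator with Bias and then accumulating K blocks gives
# the same result as computing the full GEMM and adding Bias afterwards.
M, N, K, block_K = 8, 8, 32, 8
A = torch.randn(M, K)
B = torch.randn(N, K)
Bias = torch.randn(M, N)

acc = Bias.clone()                       # analogue of T.copy(Bias_shared, C_local)
for k0 in range(0, K, block_K):          # analogue of the pipelined K loop
    acc += A[:, k0:k0 + block_K] @ B[:, k0:k0 + block_K].T

ref = A @ B.T + Bias
assert torch.allclose(acc, ref, atol=1e-5)
```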
@@ -356,7 +373,7 @@ def matmul(M,
return main

def ref_program_twiddling(A, qB, Scale, Bias=None):
    """
    Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
@@ -380,7 +397,32 @@ def ref_program_twiddling(A, qB, Scale):
    return C
def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
    """
    Compute A @ B^T + Bias where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.

    Converts the quantized matrix `qB` to floating point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2**Scale[i][j // 32] (one scale entry shared by each 32-column group of B), computes A · B^T in float, adds `Bias`, and casts the result to bfloat16.

    Parameters:
        A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
        qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
        Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
        Bias (torch.Tensor): Bias tensor with shape (M, N).

    Returns:
        torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
    """
    dtypeC = "bfloat16"
    B = torch_convert_bit_twiddling(qB)
    for i in range(B.shape[0]):
        for j in range(B.shape[1]):
            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
    C = C.to(torch.__getattribute__(dtypeC))
    return C
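
The nested loop above spells out the scaling rule one element at a time. A vectorized equivalent, shown only as a sketch and assuming `B` has shape `(N, K)` and `Scale` has shape `(N, K // 32)`:

```python
import torch

def apply_block_scales(B: torch.Tensor, Scale: torch.Tensor) -> torch.Tensor:
    # Repeat each scale exponent across its 32-column group, then scale B:
    # equivalent to B[i][j] *= 2 ** Scale[i][j // 32] from the loop above.
    factors = torch.pow(2.0, Scale.to(torch.float)).repeat_interleave(32, dim=1)
    return B.to(torch.float) * factors
```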
def ref_program_simple(A, qB, Scale, Bias=None):
""" """
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization. Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
...@@ -406,7 +448,37 @@ def ref_program_simple(A, qB, Scale): ...@@ -406,7 +448,37 @@ def ref_program_simple(A, qB, Scale):
return C return C
def ref_program_simple_with_bias(A, qB, Scale, Bias):
    """
    Compute a BF16 matrix product A · B^T + Bias from a quantized B with simple (non-twiddling) dequantization.

    Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor of 2**Scale[i][j // 32] (Scale supplies exponents in 32-column groups), computes C = A · B^T, adds `Bias`, and returns the result converted to bfloat16.

    Parameters:
        - A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
        - qB: Quantized representation of B accepted by `torch_convert`.
        - Scale: 2D tensor of exponents; Scale[i][g] is applied to columns j where g == j // 32.
        - Bias: 2D tensor with shape (M, N), added to the product before the bfloat16 cast.

    Returns:
        - 2D bfloat16 tensor C containing the matrix product A · B^T plus Bias.

    No in-place modification is performed on the inputs (a local floating-point copy of B is scaled).
    """
    dtypeC = "bfloat16"
    B = torch_convert(qB)
    for i in range(B.shape[0]):
        for j in range(B.shape[1]):
            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
    C = C.to(torch.__getattribute__(dtypeC))
    return C
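
Since each `*_with_bias` variant differs from its original only by the `+ Bias` term, the four references could in principle share one helper via the `Bias=None` default that the non-bias versions gained in this commit. The following is a hypothetical refactor sketch, not code from this file:

```python
import torch

def ref_matmul_bf16(A, B_dequant, Bias=None):
    # B_dequant stands for the already dequantized, scaled B of shape (N, K).
    C = torch.matmul(A.to(torch.float), B_dequant.T.to(torch.float))
    if Bias is not None:
        C = C + Bias          # skipped when the kernel is built with with_bias=False
    return C.to(torch.bfloat16)
```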
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False):
    """
    Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.
@@ -435,7 +507,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
else:
    kernel = matmul(
        m,
@@ -452,14 +525,21 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
num_stages=2,
threads=256,
split=1,
fast_dequant=fast_dequant,
with_bias=with_bias)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
    if with_bias:
        profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01)
    else:
        profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
    if with_bias:
        profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01)
    else:
        profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
print("All checks pass.") print("All checks pass.")
latency = profiler.do_bench(warmup=500) latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
@@ -469,5 +549,7 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
if __name__ == "__main__":
    M, N, K = 256, 256, 256
    scale_size = 32
    main(M, N, K, scale_size, fast_dequant=True, with_bias=True)
    main(M, N, K, scale_size, fast_dequant=False, with_bias=True)
    main(M, N, K, scale_size, fast_dequant=True, with_bias=False)
    main(M, N, K, scale_size, fast_dequant=False, with_bias=False)