Unverified Commit 796b3bbe authored by Zhengju Tang's avatar Zhengju Tang Committed by GitHub
Browse files

[MXFP4] Fix bugs and optimize exponential operation (#750)



* [MXFP4] Fix bugs
- Optimize exp2 with shift operation to boost performance
- Fix bug of simple dequantization function call
- Fix bug of scaling factor with bias

* [Lint]

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent e8357626
......@@ -40,8 +40,8 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the max function to limit the exponential part to 8 bits
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
......@@ -218,7 +218,7 @@ def matmul(M,
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), "float32")
Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from share memory to register.
......@@ -231,8 +231,7 @@ def matmul(M,
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.exp2(
T.cast(Scale_local_thread[0] - 127, "float"))
Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
# Then, dequant.
T.call_extern(
......@@ -288,7 +287,7 @@ def matmul(M,
- Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
"""
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
bx = T.get_block_binding(0)
T.copy(B_shared, B_local)
......@@ -300,8 +299,9 @@ def matmul(M,
Scale[
bx * block_N + i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=in_dtype,
)
dtype=out_dtype,
) * T.shift_left(
1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
......@@ -374,7 +374,7 @@ def ref_program_twiddling(A, qB, Scale):
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -400,7 +400,7 @@ def ref_program_simple(A, qB, Scale):
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -427,7 +427,15 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size)
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
......@@ -443,7 +451,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
block_K=128,
num_stages=2,
threads=256,
split=1)
split=1,
fast_dequant=fast_dequant)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment