Unverified commit 796b3bbe, authored by Zhengju Tang, committed by GitHub

[MXFP4] Fix bugs and optimize exponential operation (#750)



* [MXFP4] Fix bugs
- Optimize exp2 with a shift operation to boost performance (see the sketch below the commit metadata)
- Fix a bug in the simple dequantization function call
- Fix a bug in the scaling factor's bias handling

* [Lint]

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent e8357626
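Note on the exp2-to-shift optimization (an explanatory sketch added here, not part of the commit): each MXFP4 group's scale is stored as an integer exponent, so the factor 2**scale can be produced with an integer left shift instead of a floating-point exp2, which is what the kernel change below does via T.shift_left. A minimal plain-Python check of that identity, with illustrative helper names:

    # The old kernel path computed the scale factor as exp2 on a float cast;
    # the new path builds the same power of two with an integer shift.
    def scale_factor_exp2(e: int) -> float:
        return 2.0 ** e        # old path: floating-point exponential

    def scale_factor_shift(e: int) -> float:
        return float(1 << e)   # new path: integer shift, then cast

    for e in range(16):        # small unsigned exponents, enough for the demo
        assert scale_factor_exp2(e) == scale_factor_shift(e)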
@@ -40,8 +40,8 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
     # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
     e_bf16 = e_f4 + tir.const(126, "uint16")
     # Scale is the exponential part, within the representation of uint8
-    # To handle the overflow, we use the max function to limit the exponential part to 8 bits
-    e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
+    # To handle the overflow, we may use the min function to limit the exponential part to 8 bits
+    # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
     m_f4 = f4 & tir.const(1, "uint16")
     val_bf16 = tir.reinterpret("bfloat16",
                                ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
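For reference (a sketch added alongside the hunk above, not code from the commit): the constant 126 is the gap between the bfloat16 exponent bias, 2^(8-1) - 1 = 127, and the FP4 (E2M1) exponent bias, 2^(2-1) - 1 = 1. The snippet below re-biases a normal FP4 value into bfloat16 bits the same way the kernel assembles them; subnormals, zero and the uint8 scale are ignored, and the helper names are illustrative only.

    import struct

    def fp4_e2m1_to_bf16_bits(s: int, e: int, m: int) -> int:
        """Assemble bfloat16 bits from FP4 (E2M1) fields for normal values (e > 0)."""
        e_bf16 = e + 126              # re-bias: 127 (bf16) - 1 (e2m1) = 126
        m_bf16 = m << 6               # 1-bit mantissa -> top bit of the 7-bit mantissa
        return (((s << 8) | e_bf16) << 7) | m_bf16

    def bf16_bits_to_float(bits: int) -> float:
        # bfloat16 is the upper 16 bits of an IEEE-754 float32
        return struct.unpack(">f", (bits << 16).to_bytes(4, "big"))[0]

    for s, e, m, expected in [(0, 1, 0, 1.0), (0, 2, 1, 3.0), (1, 3, 1, -6.0)]:
        assert bf16_bits_to_float(fp4_e2m1_to_bf16_bits(s, e, m)) == expected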
@@ -218,7 +218,7 @@ def matmul(M,
     B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
     B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
     Scale_local_thread = T.alloc_local((1,), storage_dtype)
-    Scale_local_thread_exponent = T.alloc_local((1,), "float32")
+    Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)
     for i in T.serial(0, block_N * block_K // threads // local_size):
         # First, load data from share memory to register.
@@ -231,8 +231,7 @@ def matmul(M,
         si = index_scale // (block_K // scale_size)
         sj = index_scale % (block_K // scale_size)
         Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj]
-        Scale_local_thread_exponent[0] = T.exp2(
-            T.cast(Scale_local_thread[0] - 127, "float"))
+        Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
         # Then, dequant.
         T.call_extern(
@@ -288,7 +287,7 @@ def matmul(M,
         - Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
         """
         B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
-        B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
+        B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
         bx = T.get_block_binding(0)
         T.copy(B_shared, B_local)
@@ -300,8 +299,9 @@ def matmul(M,
                 Scale[
                     bx * block_N + i, k * block_K // scale_size + j //
                     scale_size], # Scale is the exponential part, within the representation of uint8
-                dtype=in_dtype,
-            )
+                dtype=out_dtype,
+            ) * T.shift_left(
+                1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
         T.copy(B_dequantize_local, B_dequantize_shared)
     return simple_dequant_bf16_fp4
@@ -374,7 +374,7 @@ def ref_program_twiddling(A, qB, Scale):
     B = torch_convert_bit_twiddling(qB)
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
+            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -400,7 +400,7 @@ def ref_program_simple(A, qB, Scale):
     B = torch_convert(qB)
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
+            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
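Both reference programs above now apply the per-group scale as 2**Scale[i][j // 32], with the earlier -127 bias removed. As a side note (a vectorized sketch, not from the commit, assuming Scale has shape (N, K // scale_size) as the indexing implies), the same scaling can be written without the Python double loop:

    import torch

    def apply_group_scale(B: torch.Tensor, Scale: torch.Tensor, scale_size: int = 32) -> torch.Tensor:
        # Multiply each group of `scale_size` consecutive columns of B by 2**Scale,
        # equivalent to the loop B[i][j] *= 2 ** Scale[i][j // scale_size].
        factors = torch.exp2(Scale.to(torch.float32))                 # (N, K // scale_size)
        return B.to(torch.float32) * factors.repeat_interleave(scale_size, dim=1)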
@@ -427,7 +427,15 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
     if tune:
         kernel = matmul(
-            m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size)
+            m,
+            n,
+            k,
+            "bfloat16",
+            "bfloat16",
+            "float32",
+            num_bits=4,
+            scale_size=scale_size,
+            fast_dequant=fast_dequant)
     else:
         kernel = matmul(
             m,
@@ -443,7 +451,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
             block_K=128,
             num_stages=2,
             threads=256,
-            split=1)
+            split=1,
+            fast_dequant=fast_dequant)
     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
...