"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "c6c361d80ada8117e926bd24f71f50bb5da9f0b3"
Unverified Commit 569b0127 authored by Zhengju Tang, committed by GitHub

Low-bit kernels fix and implementation (#704)



* [MXFP4] Dequantize FP4 kernel example, MX scale todo

* [BugFix] Fix the bug of fp4&fp16 exponential bias

* [MXFP4] Add group scale factor for BF16xMXFP4 gemm

* [Lint]

* [Test] Add test script for BF16xMXFP4 gemm

* [Lint]

* [BugFix] Fix the shape of scale tensor

* Update example_dequant_gemm_fp4_hopper.py

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 376ba9eb
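The [BugFix] item above comes down to one constant: an FP4 (E2M1) biased exponent e_f4 re-encodes in FP16 as e_f4 + (15 - 1) = e_f4 + 14 rather than e_f4 | 8, and in BF16 as e_f4 + (127 - 1) = e_f4 + 126. A minimal sketch of that arithmetic (illustrative only, not code from this commit):

# Exponent bias offsets behind the fixed dequantization paths (illustrative sketch).
FP4_E2M1_BIAS = 2**(2 - 1) - 1  # = 1
FP16_BIAS = 2**(5 - 1) - 1      # = 15
BF16_BIAS = 2**(8 - 1) - 1      # = 127

assert FP16_BIAS - FP4_E2M1_BIAS == 14   # constant used in _tir_u8_to_f4_to_f16
assert BF16_BIAS - FP4_E2M1_BIAS == 126  # constant used in _tir_u8_to_f4_to_bf16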
@@ -12,16 +12,18 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
     assert dtype == "float16"
     assert val.dtype == "uint8"
     # e_f4 == 0 -> e_f16 = 0
-    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
-    # s1e2n1
+    # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
+    # s1e2m1
     mask = tir.const((1 << nbit) - 1, "uint16")
     f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
     s = f4 >> tir.const(3, "uint16")
-    e_f4 = f4 & tir.const(7, "uint16")
-    e_f16 = e_f4 | tir.const(8, "uint16")
-    val_f16 = tir.reinterpret(
-        "float16",
-        ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
+    e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
+    e_f16 = e_f4 + tir.const(14, "uint16")
+    m_f4 = f4 & tir.const(1, "uint16")
+    m_f16 = m_f4
+    val_f16 = tir.reinterpret("float16",
+                              ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")
+                               | m_f16 << tir.const(9, "uint16")).astype("uint16"))
     # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
     return val_f16
@@ -39,9 +41,11 @@ def torch_convert(tensor):
         mask = (1 << 4) - 1
         f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
         s = f4 >> 3
-        e_f4 = f4 & 7
-        e_f16 = e_f4 | 8
-        val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
+        e_f4 = (f4 & 6) >> 1
+        e_f16 = e_f4 + 14
+        m_f4 = f4 & 1
+        m_f16 = m_f4
+        val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF
         lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
         return lower_16_bits.view(torch.float16)
...
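To see what the updated bit packing produces, the sketch below (a hypothetical helper, not part of the commit) replays the new torch_convert arithmetic for each FP4 nibble; nibbles with e_f4 == 0 still take the normal-exponent path, as the commented-out tir.Select above indicates.

# Illustrative sketch: decode each FP4 nibble with the updated FP16 bit layout.
import torch

def fp4_nibble_to_fp16(f4: int) -> float:
    s = f4 >> 3
    e_f4 = (f4 & 6) >> 1
    m_f4 = f4 & 1
    e_f16 = e_f4 + 14  # FP16 bias (15) minus E2M1 bias (1)
    bits = ((e_f16 | (s << 5)) << 10) | (m_f4 << 9)
    return torch.tensor([bits], dtype=torch.uint16).view(torch.float16).item()

for nib in range(16):
    print(f"{nib:04b} -> {fp4_nibble_to_fp16(nib)}")  # e.g. 0011 -> 1.5, 0111 -> 6.0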
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import argparse
import itertools
import torch
tilelang.disable_cache()
torch.manual_seed(0)
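# Format note for this example: MXFP4 here means FP4 E2M1 elements packed two per uint8,
# with one uint8 scale shared by every group of `scale_size` values (32 by default).
# The shared scale is applied as an exponent offset during dequantization rather than
# as a separate floating-point multiply.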
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
    # The exponent bias difference between bf16 and f4 (e2m1) is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
    # Scale is a per-group exponent offset, stored as a uint8
    # Clamp the biased exponent with min so it cannot overflow the 8-bit exponent field
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16
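# Worked example: nibble 0b0011 is s=0, e=1, m=1, i.e. 2^(1-1) * 1.5 = 1.5 in E2M1.
# The code above maps it to e_bf16 = 1 + 126 = 127 with mantissa bit 6 set, which is
# bf16 1.5; with scale = 2 the exponent becomes 129 and the result is 2^2 * 1.5 = 6.0.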
def torch_convert(tensor, scale_size=None, Scale=None):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
def _convert(val, pos, scale=None):
assert val.dtype == torch.uint8
# val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 126
if scale is not None:
e_f16 = min(e_f16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)
N = tensor.shape[0]
K = tensor.shape[1]
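    # Each uint8 packs two FP4 values, so the dequantized reference has 2 * K columns.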
new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
if scale_size is not None:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size])
else:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
@tilelang.jit(out_idx=[-1])
def convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
0, # No scale for test
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
@tilelang.jit(out_idx=[-1])
def convert_scale(N, K, block_N, block_K, in_dtype, num_bits=4, scale_size=32, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
Scale_shape = (N, K // scale_size)
Scale_shared_shape = (block_N, block_K // scale_size)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype)
Scale_local = T.alloc_fragment(Scale_shared_shape, storage_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared)
T.copy(Scale_shared, Scale_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
                        Scale_local[i, j // scale_size],  # per-group exponent offset stored as uint8
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def test_fp4_bf16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
kernel = convert(
N,
K,
block_N,
block_K,
"bfloat16",
)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = kernel(B)
ref_out = torch_convert(B)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Convert Pass")
def test_fp4_bf16_convert_scale_close():
N, K = 256, 256
block_N, block_K = 64, 64
kernel = convert_scale(N, K, block_N, block_K, "bfloat16", scale_size=32)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
Scale = torch.randint(0, 1, (N, K // 32), dtype=torch.uint8, device="cuda").to(torch.uint8)
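    # torch.randint(0, 1, ...) samples from [0, 1), so Scale is all zeros here and the
    # scale path is exercised with a zero exponent offset.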
tl_out = kernel(B, Scale)
ref_out = torch_convert(B, scale_size=32, Scale=Scale)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Convert Scale Pass")
def get_configs():
block_M = [128]
block_N = [128, 256]
block_K = [128]
num_stages = [2]
threads = [256]
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, scale_size=32, tune=False):
@tilelang.jit(out_idx=[-1])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
Scale_shape = (N, K // scale_size)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
Scale_shared_shape = (block_N, block_K // scale_size)
assert K % (block_K * split) == 0
KK = K // split
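        # Split-K: each bz slice reduces KK = K // split elements of the reduction axis;
        # the partial results written to SplitC are summed by a second kernel below.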
@T.prim_func
def main_split(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
SplitC = T.alloc_buffer([
split, (N + block_N - 1) // block_N * block_N,
(M + block_M - 1) // block_M * block_M
], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype)
Scale_local = T.alloc_fragment(Scale_shared_shape, storage_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
T.copy(A[by * block_M, KK * bz + k * block_K], A_shared)
T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
T.copy(Scale[bx * block_N, (KK * bz + k * block_K) // scale_size], Scale_shared)
T.copy(Scale_shared, Scale_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_local[i, j // scale_size],
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
for k in range(split):
for i, j in T.Parallel(block_N, block_M):
acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
T.copy(acc, Ct[bx * block_N, by * block_M])
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
Scale_shared = T.alloc_shared((block_N, block_K // scale_size), storage_dtype)
Scale_local = T.alloc_fragment((block_N, block_K // scale_size), storage_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared)
T.copy(Scale_shared, Scale_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_local[i, j // scale_size],
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
if split == 1:
return main
else:
return main_split
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[-1])
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel
def ref_program(A, qB):
dtypeC = "bfloat16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
def ref_program_scale(A, qB, Scale):
dtypeC = "bfloat16"
B = torch_convert(qB, scale_size=32, Scale=Scale)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
def main(m=256, n=256, k=256, scale_size=32, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program_scale, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_scale, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
def test_convert():
test_fp4_bf16_convert_close()
test_fp4_bf16_convert_scale_close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument(
'--scale_size',
type=int,
default=32,
        help='group size of the shared scale: each group of scale_size FP4 values shares one uint8 exponent offset')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
# test_convert()
main(M, N, K, args.scale_size, args.tune)
@@ -2,6 +2,7 @@ import tilelang.testing
 import example_dequant_gemv_fp16xint4
 import example_dequant_gemm_fp4_hopper
+import example_dequant_gemm_mxfp4_hopper
 
 
 @tilelang.testing.requires_cuda
@@ -15,5 +16,11 @@ def test_example_dequant_gemm_fp4_hopper():
     example_dequant_gemm_fp4_hopper.main()
 
 
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_gemm_mxfp4_hopper():
+    example_dequant_gemm_mxfp4_hopper.main()
+
+
 if __name__ == "__main__":
     tilelang.testing.main()