Unverified Commit 24603e4a authored by Zhengju Tang, committed by GitHub

[Feature] Low-bit twiddling dequantization and FP4 GEMM (#725)



* [Dequant] Add bit-twiddling dequantization CUDA for fp4 --> bf16

* [Dequant] Add extern call and serial dequantization

* [Dequant] Parallel dequantization; pending fence debugging

* [Scale] Add scale matrix to mxfp4 gemm

* [Remove] Remove fence-buggy example and some generated source cuda code

* [MXFP4] Update initial version of MXFP4 GEMM

* [Scale] Add scale to latest mxfp4 gemm

* [Lint]

* [BugFix] Load Scale, disable TMA to recover performance

* [Lint]

* [Lint]

* [Scale] Holding Scale in L2 and enabling TMA slightly boosts performance

* [Lint]

* Update example_dequant_gemm_bf16_fp4_hopper_serial.py

* Remove deprecated dequantization examples for BF16 and MXFP4 in the dequantize_gemm directory.

* Refactor dequantization examples for improved readability and consistency. Adjusted formatting in matmul function and added spacing for clarity. Updated function signatures and comments for better understanding.

* Refactor index_to_coordinates usage in bitnet example and update dequantization example configurations. Removed the custom index_to_coordinates function and replaced it with the built-in version. Adjusted block_K parameter in dequantization example for consistency.

* lint fix

* ci fix

* Remove non-existent example

* [BugFix] Add smem swizzle to recover performance of TMA

* [BugFix] Enough reg for producer when threads=512

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent a86223f4
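Before the diff itself, here is a small, self-contained sketch (an editor's illustration, not part of the commit) of the FP4 (E2M1) to BF16 bit mapping that the new dequantization helpers below implement; the function name and sample value are made up for the example.

def fp4_nibble_to_bf16_bits(byte: int, pos: int) -> int:
    # Unpack one 4-bit E2M1 value from a packed uint8 and widen it to the
    # raw bfloat16 bit pattern, mirroring _tir_u8_to_f4_to_bf16 (scale = 0).
    f4 = (byte >> (pos * 4)) & 0xF      # 4-bit field: sign | 2-bit exponent | 1-bit mantissa
    s = f4 >> 3
    e = (f4 & 0b0110) >> 1
    m = f4 & 0b0001
    e_bf16 = e + 126                    # re-bias: bf16 bias (127) minus e2m1 bias (1)
    # bf16 layout: sign | 8-bit exponent | 7-bit mantissa (only the top mantissa bit is used)
    return (s << 15) | (e_bf16 << 7) | (m << 6)

print(hex(fp4_nibble_to_bf16_bits(0b01010011, 0)))  # low nibble 0b0011 -> 0x3fc0, i.e. bf16 1.5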
@@ -9,7 +9,6 @@ from tilelang import tvm as tvm
 from tvm import DataType
 from tilelang.intrinsics.mma_layout import (
     make_mma_swizzle_layout as make_swizzle_layout,)
-from tilelang.intrinsics.utils import index_to_coordinates
 import numpy as np
 from tilelang.intrinsics.mma_macro_generator import (
@@ -200,7 +199,7 @@ def bitnet_158_int8xint2_prefill(
 index = (
     i * threads * local_size_compressed +
     thread_bindings * local_size_compressed + v)
-vi, vj = index_to_coordinates(index, B_shared_shape)
+vi, vj = T.index_to_coordinates(index, B_shared_shape)
 B_local[v] = B_shared[vi, vj]
 T.call_extern(
@@ -212,7 +211,7 @@ def bitnet_158_int8xint2_prefill(
 for v in T.vectorized(0, local_size):
 index = (i * threads * local_size + thread_bindings * local_size + v)
-vi, vj = index_to_coordinates(index, B_dequantize_shared_shape)
+vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape)
 B_dequantize_shared[vi, vj] = B_dequantize_local[v]
 for ki in T.serial(0, (block_K // micro_size_k)):
...
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tvm import tir
import torch
from utils import torch_convert_bit_twiddling, torch_convert
def get_configs():
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[128],
num_stages=[0, 2],
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
},
)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info is not found"
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
# import fast_dequantize plugin
T.import_source(import_source)
tx = T.get_thread_binding()
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from shared memory into registers.
# Prepare for dequant.
for v in T.vectorized(0, local_compress_size):
index = i * threads * local_compress_size + tx * local_compress_size + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
scale: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponent part, stored as a uint8
# To avoid overflow, the exponent is clamped to 8 bits with min
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret(
"bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_shared[i, j // num_elems_per_byte],
j % num_elems_per_byte,
0, # No scale for test
dtype=out_dtype,
)
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
return main
def ref_program_twiddling(A, qB):
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
def ref_program_simple(A, qB):
dtypeC = "bfloat16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
fast_dequant=fast_dequant,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
main(256, 256, 256, True)
main(256, 256, 256, False)
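As a quick sanity check of the per-thread tiling arithmetic inside fast_dequant_bf16_fp4_twiddling above, the following standalone sketch (an illustration assuming the default block_N=128, block_K=128, threads=256 configuration used in main) spells out the constants each thread works with:

# Per-thread work decomposition for the fast twiddling path (sketch only).
MAX_TRANSACTION_SIZE_BITS = 128
BF16_BITS = 16
NUM_ELEMS_PER_BYTE = 2                                    # two 4-bit values per uint8
block_N, block_K, threads = 128, 128, 256                 # defaults used in main()
local_size = MAX_TRANSACTION_SIZE_BITS // BF16_BITS       # 8 bf16 outputs per extern call
local_compress_size = local_size // NUM_ELEMS_PER_BYTE    # 4 packed uint8 inputs per call
iterations = block_N * block_K // threads // local_size   # 8 outer iterations per thread
print(local_size, local_compress_size, iterations)        # 8 4 8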
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tvm import tir
import torch
from utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponent part, stored as a uint8
# To avoid overflow, the exponent is clamped to 8 bits with min
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16
def get_configs():
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[64, 128, 256],
num_stages=[0, 2],
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1],)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
Scale_shape = (N, K // scale_size)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info is not found"
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
# import fast_dequantize plugin
T.import_source(import_source)
tx = T.get_thread_binding()
bx = T.get_block_binding(0)
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), "float32")
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from shared memory into registers.
# Prepare for dequant.
index_base = i * threads * local_compress_size + tx * local_compress_size
for v in T.vectorized(0, local_compress_size):
index = index_base + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
index_scale = index_base // (scale_size // num_elems_per_byte)
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.exp2(
T.cast(Scale_local_thread[0] - 127, "float"))
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.Parallel(local_size):
B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0]
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
bx = T.get_block_binding(0)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale[
bx * block_N + i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
if threads == 512:
T.no_set_max_nreg()
T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale, k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
return main
def ref_program_twiddling(A, qB, Scale):
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
def ref_program_simple(A, qB, Scale):
dtypeC = "bfloat16"
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
M, N, K = 256, 256, 256
scale_size = 32
main(M, N, K, scale_size, fast_dequant=True)
main(M, N, K, scale_size, fast_dequant=False)
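The reference programs above rescale B with a Python double loop; an equivalent vectorized formulation (a sketch that assumes the same (N, K // scale_size) uint8 Scale layout with scale_size=32, and a hypothetical function name) is:

import torch
from utils import torch_convert_bit_twiddling  # same helper the example imports

def ref_program_twiddling_vectorized(A, qB, Scale, scale_size=32):
    # Same math as ref_program_twiddling: B[i, j] *= 2 ** (Scale[i, j // scale_size] - 127)
    B = torch_convert_bit_twiddling(qB).to(torch.float32)
    factors = torch.exp2(Scale.to(torch.float32) - 127)      # (N, K // scale_size)
    B = B * factors.repeat_interleave(scale_size, dim=1)     # expand one factor per group
    return torch.matmul(A.to(torch.float32), B.T).to(torch.bfloat16)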
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import argparse
import itertools
import torch
tilelang.disable_cache()
torch.manual_seed(0)
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponent part, stored as a uint8
# To avoid overflow, the exponent is clamped to 8 bits with min
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16
def torch_convert(tensor, scale_size=None, Scale=None):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
def _convert(val, pos, scale=None):
assert val.dtype == torch.uint8
# val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 126
if scale is not None:
e_f16 = min(e_f16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
if scale_size is not None:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size])
else:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
@tilelang.jit(out_idx=[-1])
def convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
0, # No scale for test
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
@tilelang.jit(out_idx=[-1])
def convert_scale(N, K, block_N, block_K, in_dtype, num_bits=4, scale_size=32, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
Scale_shape = (N, K // scale_size)
Scale_shared_shape = (block_N, block_K // scale_size)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype)
Scale_local = T.alloc_fragment(Scale_shared_shape, storage_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared)
T.copy(Scale_shared, Scale_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_local[
i, j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def test_fp4_bf16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
kernel = convert(
N,
K,
block_N,
block_K,
"bfloat16",
)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = kernel(B)
ref_out = torch_convert(B)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Convert Pass")
def test_fp4_bf16_convert_scale_close():
N, K = 256, 256
block_N, block_K = 64, 64
kernel = convert_scale(N, K, block_N, block_K, "bfloat16", scale_size=32)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
Scale = torch.randint(0, 1, (N, K // 32), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = kernel(B, Scale)
ref_out = torch_convert(B, scale_size=32, Scale=Scale)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Convert Scale Pass")
def get_configs():
block_M = [128]
block_N = [128, 256]
block_K = [128]
num_stages = [2]
threads = [256]
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, scale_size=32, tune=False):
@tilelang.jit(out_idx=[-1])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
Scale_shape = (N, K // scale_size)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
Scale_shared_shape = (block_N, block_K // scale_size)
assert K % (block_K * split) == 0
KK = K // split
@T.prim_func
def main_split(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
SplitC = T.alloc_buffer([
split, (N + block_N - 1) // block_N * block_N,
(M + block_M - 1) // block_M * block_M
], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype)
Scale_local = T.alloc_fragment(Scale_shared_shape, storage_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
T.copy(A[by * block_M, KK * bz + k * block_K], A_shared)
T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
T.copy(Scale[bx * block_N, (KK * bz + k * block_K) // scale_size], Scale_shared)
T.copy(Scale_shared, Scale_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_local[i, j // scale_size],
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
for k in range(split):
for i, j in T.Parallel(block_N, block_M):
acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
T.copy(acc, Ct[bx * block_N, by * block_M])
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
Scale_shared = T.alloc_shared((block_N, block_K // scale_size), storage_dtype)
Scale_local = T.alloc_fragment((block_N, block_K // scale_size), storage_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared)
T.copy(Scale_shared, Scale_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_local[i, j // scale_size],
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
if split == 1:
return main
else:
return main_split
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[-1])
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel
def ref_program(A, qB):
dtypeC = "bfloat16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
def ref_program_scale(A, qB, Scale):
dtypeC = "bfloat16"
B = torch_convert(qB, scale_size=32, Scale=Scale)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
def main(m=256, n=256, k=256, scale_size=32, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program_scale, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_scale, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
def test_convert():
test_fp4_bf16_convert_close()
test_fp4_bf16_convert_scale_close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument(
'--scale_size',
type=int,
default=32,
help='scale group size; each group of values shares one uint8 exponent scale')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
# test_convert()
main(M, N, K, args.scale_size, args.tune)
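For reference, the entry points above can also be driven directly from Python without the argparse flags (a sketch; assumes a CUDA GPU that meets the examples' requirements):

# Uncommented version of the test_convert() call left in __main__ above,
# followed by a small non-tuned GEMM run with the default sizes.
test_convert()                                        # fp4 -> bf16 convert checks
main(m=256, n=256, k=256, scale_size=32, tune=False)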
@@ -2,7 +2,7 @@ import tilelang.testing
 import example_dequant_gemv_fp16xint4
 import example_dequant_gemm_fp4_hopper
-import example_dequant_gemm_mxfp4_hopper
+import example_dequant_gemm_bf16_fp4_hopper_serial
 @tilelang.testing.requires_cuda
@@ -18,8 +18,8 @@ def test_example_dequant_gemm_fp4_hopper():
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_dequant_gemm_mxfp4_hopper():
-    example_dequant_gemm_mxfp4_hopper.main()
+def test_example_dequant_gemm_bf16_fp4_hopper_serial():
+    example_dequant_gemm_bf16_fp4_hopper_serial.main()
 if __name__ == "__main__":
...
import torch
def torch_convert_bit_twiddling(tensor):
def _convert(val0, val1, pos) -> torch.bfloat16:
assert val0.dtype == torch.uint8
assert val1.dtype == torch.uint8
val0 = val0.view(torch.uint8)
val1 = val1.view(torch.uint8)
val_concat = (val0.item() << 8) | val1.item()
mask = 0b1000000111000000
if pos == 0:
bf16 = val_concat & mask
elif pos == 1:
bf16 = (val_concat << 3) & mask
elif pos == 2:
bf16 = (val_concat << 6) & mask
elif pos == 3:
mask1 = 0b1000000000000000
mask2 = 0b0000000110000000
mask3 = 0b0000000001000000
bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | (
(val_concat >> 7) & mask3)
bf16_new = torch.tensor([bf16], dtype=torch.uint16, device=val0.device).view(torch.bfloat16)
# Add bias for change from fp4 to bf16
bf16_new = bf16_new.item() * (2**126)
return bf16_new
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 4 * 2], tensor[i][j // 4 * 2 + 1], j % 4)
return new_tensor
def torch_convert(tensor, scale_size=None, Scale=None):
def _convert(val, pos, scale=None):
assert val.dtype == torch.uint8
# val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 126
if scale is not None:
e_f16 = min(e_f16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
if scale_size is not None:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size])
else:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
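A minimal usage sketch for the helpers above (random packed data; only the shapes matter here, and a CUDA device is assumed as in the examples):

import torch
# torch_convert and torch_convert_bit_twiddling are the helpers defined above.
qB = torch.randint(0, 256, (4, 8), dtype=torch.uint8, device="cuda")   # 4 rows of 8 packed bytes
B_plain = torch_convert(qB)                            # (4, 16) bfloat16, two fp4 values per byte
Scale = torch.zeros(4, 1, dtype=torch.uint8, device="cuda")
B_scaled = torch_convert(qB, scale_size=16, Scale=Scale)   # adds Scale to the exponent field per group
B_twiddled = torch_convert_bit_twiddling(qB)           # (4, 16), bit-twiddled byte-pair packing
print(B_plain.shape, B_scaled.shape, B_twiddled.shape)

The 2 ** 126 factor applied inside torch_convert_bit_twiddling is the same exponent re-bias constant that the CUDA fast path loads as 0x7e807e80 and applies with mul.bf16x2.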
 import argparse
-import torch
 import itertools
 import tilelang as tl
 import tilelang.language as T
@@ -8,6 +7,7 @@ from tilelang.carver.template import MatmulTemplate
 from tilelang.carver.arch import CUDA
 from tilelang.carver.arch import CDNA
 from tilelang.carver.roller.rasterization import NoRasterization
+import torch
 def ref_program(A, B):
...
#!/bin/bash
# Set ROOT_DIR to the project root (two levels up from this script's directory)
ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd)
# Change to the project root directory for local testing of changes
cd $ROOT_DIR
# Add the project root to PYTHONPATH so Python can find local modules
export PYTHONPATH=$ROOT_DIR:$PYTHONPATH
# Run pytest in parallel (4 workers) for all tests in the examples directory
cd examples
python -m pytest -n 4 .
cd ..
# Run pytest in parallel (4 workers) for all tests in the testing/python directory
cd testing/python
python -m pytest -n 4 .
cd ..
@@ -81,24 +81,3 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]):
     if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:
         micro_size_k = 32
     return micro_size_x, micro_size_y, micro_size_k
-def index_to_coordinates(index, shape):
-    '''
-    General Implementation of:
-        vjj = index % (micro_size_k // num_elems_per_byte)
-        coordinates[-1] = index % shape[-1];
-        vii = index // (micro_size_k // num_elems_per_byte) % micro_size_y
-        index = index // shape[-1]; coordinates[-2] = index % shape[-2];
-        vj = index // (micro_size_k // num_elems_per_byte * micro_size_y) % block_K // (micro_size_k // num_elems_per_byte)
-        index = index // shape[-2]; coordinates[-3] = index % shape[-3];
-        vi = index // (micro_size_k // num_elems_per_byte * micro_size_y * (block_K // (micro_size_k // num_elems_per_byte))) % block_N // micro_size_y
-        index = index // shape[-3]; coordinates[-4] = index % shape[-4];
-    '''
-    coordinates = []
-    dims = len(shape)
-    for i in range(dims):
-        coordinates.append(index % shape[dims - i - 1])
-        index = index // shape[dims - i - 1]
-    coordinates.reverse()
-    return coordinates
@@ -69,6 +69,7 @@ from .logical import any_of, all_of  # noqa: F401
 from .builtin import *  # noqa: F401
 from .memscope import *  # noqa: F401
+from .utils import index_to_coordinates  # noqa: F401
 def symbolic(name: str, dtype: str = "int32"):
...
@@ -7,7 +7,7 @@ from tvm import tir
 from typing import Any
 from tilelang.language.kernel import get_thread_bindings
 from tilelang.language import copy, macro, serial, alloc_shared
-from tilelang.intrinsics.utils import index_to_coordinates
+from tilelang.language.utils import index_to_coordinates
 @macro
...
from tilelang import tvm as tvm
from tvm.tir import PrimExpr
def index_to_coordinates(index, shape) -> list[PrimExpr]:
"""
Convert a flat (linear) index to multi-dimensional coordinates for a given shape.
Example:
shape = (4, 5, 6)
index = 53
index_to_coordinates(53, (4, 5, 6)) -> [1, 3, 5]
# Explanation:
# 53 // (5*6) = 1 (1st coordinate)
# 53 % (5*6) = 23
# 23 // 6 = 3 (2nd coordinate)
# 23 % 6 = 5 (3rd coordinate)
Args:
index (int): The flat index to convert.
shape (tuple or list of int): The shape of the multi-dimensional array.
Returns:
list: A list of coordinates corresponding to each dimension.
"""
coordinates = []
dims = len(shape)
for i in range(dims):
coordinates.append(index % shape[dims - i - 1])
index = index // shape[dims - i - 1]
coordinates.reverse()
return coordinates
def linear_index(*args: PrimExpr) -> PrimExpr:
"""
Convert a list of coordinates to a flat (linear) index using strides.
Usage examples:
linear_index(i) -> i
linear_index(i, j, stride_j) -> i * stride_j + j
linear_index(i, j, k, stride_j, stride_k)
-> i * stride_j * stride_k + j * stride_k + k
Example for index = i * threads * local_size + tx * local_size + v:
Suppose you have i, tx, v as coordinates, and threads, local_size as strides:
linear_index(i, tx, v, threads, local_size) == i * threads * local_size + tx * local_size + v
"""
n = len(args)
if n == 0:
raise ValueError("At least one index is required")
if n == 1:
return args[0]
# args is laid out as coordinates followed by strides for every dimension except the
# leading one, so the number of strides is always one less than the number of coordinates.
num_coords = (n + 1) // 2
coords = args[:num_coords]
strides = args[num_coords:]
if len(strides) != len(coords) - 1:
raise ValueError("Stride count must be one less than coordinate count")
linear = coords[0]
for idx, stride in zip(coords[1:], strides):
linear = linear * stride + idx
return linear
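A small round-trip sketch tying the two helpers above together (plain Python ints, matching the docstring examples):

shape = (4, 5, 6)
i, j, k = index_to_coordinates(53, shape)          # -> 1, 3, 5, as in the docstring
flat = linear_index(i, j, k, shape[1], shape[2])   # strides for every dim except the first
assert flat == 53                                  # 1*5*6 + 3*6 + 5 == 53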
@@ -14,3 +14,4 @@ from .utils import (
 )
 from .lop3 import get_lop3_intrin_group  # noqa: F401
+from .mxfp import get_mxfp_intrin_group  # noqa: F401
from typing import Literal, Dict
# Implementation asm for fp4 to bf16, using twiddling
# Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18
decode_f4_to_bf16_twiddling = """
// N should be the number of elements processed by one thread
template<typename T1, typename T2>
__device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, const int N = 8) {
#pragma unroll
for (int i = 0; i < N; ++i) {
uint B_dequantize_local_vec[4];
uint tmp, bias, d0, d1, d2, d3, d4, d5, d6;
asm volatile(
// To handle the endianness issue
"prmt.b32 %13, %4, 0, 0x0123;"
"mov.b32 %12, 0x7e807e80;"
"and.b32 %0, %13, 0b10000001110000001000000111000000;"
"mul.bf16x2 %0, %0, %12;"
"shl.b32 %1, %13, 3;"
"and.b32 %1, %1, 0b10000001110000001000000111000000;"
"mul.bf16x2 %1, %1, %12;"
"shl.b32 %2, %13, 6;"
"and.b32 %2, %2, 0b10000001110000001000000111000000;"
"mul.bf16x2 %2, %2, %12;"
"shl.b32 %5, %13, 1;"
"and.b32 %6, %5, 0b10000000000000001000000000000000;"
"shr.b32 %7, %13, 3;"
"and.b32 %8, %7, 0b00000001100000000000000110000000;"
"or.b32 %9, %6, %8;"
"shr.b32 %10, %13, 7;"
"and.b32 %11, %10, 0b00000000010000000000000001000000;"
"or.b32 %3, %9, %11;"
"mul.bf16x2 %3, %3, %12;"
:"=r"(B_dequantize_local_vec[0])
,"=r"(B_dequantize_local_vec[1])
,"=r"(B_dequantize_local_vec[2])
,"=r"(B_dequantize_local_vec[3])
:"r"(*(uint*)&B_local[i << 2]), "r"(d0), "r"(d1), "r"(d2), "r"(d3), "r"(d4), "r"(d5), "r"(d6), "r"(bias), "r"(tmp)
);
for (int j = 0; j < 4; ++j) {
// Pay attention to the big-endianness issue
B_local_decode[(i << 3) + j] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[1];
B_local_decode[(i << 3) + j + 4] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[0];
}
}
// Check if the synchronization is needed
}
"""
def get_mxfp_intrin_group(
out_dtype: Literal["float16", "bfloat16"] = "bfloat16",
source_format: Literal["int", "uint"] = "uint",
source_bit: int = 4,
storage_dtype: Literal["int32", "int8", "uint8"] = "uint8",
use_twiddling: bool = False,
) -> Dict[str, str]:
"""
Return the intrinsic group used for fast MXFP dequantization: the CUDA source implementing
the decode routine for low-bit floating-point values (e.g. fp4) packed into the storage dtype,
together with the name of the device function to call.
"""
assert out_dtype in ["float16", "bfloat16"
], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
assert source_format in ["int", "uint"
], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'."
assert storage_dtype in [
"int32", "int8", "uint8"
], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'."
dtype_map = {"float16": "f16", "bfloat16": "bf16"}
key = f"fp{source_bit}_to_{dtype_map[out_dtype]}"
if use_twiddling:
key += "_twiddling"
import_c_map = {
"fp4_to_bf16_twiddling": decode_f4_to_bf16_twiddling,
}
func_name = f"decode_fp{source_bit}_to_{dtype_map[out_dtype]}"
if use_twiddling:
func_name += "_twiddling"
return {
"func_name": func_name,
"c_source": import_c_map[key],
}
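A usage sketch mirroring how the Hopper examples above consume this group:

from tilelang.quantize import get_mxfp_intrin_group

info = get_mxfp_intrin_group(
    out_dtype="bfloat16",
    source_format="uint",
    source_bit=4,
    storage_dtype="uint8",
    use_twiddling=True,
)
print(info["func_name"])  # decode_fp4_to_bf16_twiddling
# info["c_source"] holds the CUDA snippet above; the examples inject it with
# T.import_source(...) and call it per thread via T.call_extern(info["func_name"], ...).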
@@ -27,6 +27,26 @@ from tvm import tir
 # fmt: off
+def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
+                          dtype: str):
+    assert nbit == 4
+    assert dtype == "bfloat16"
+    assert val.dtype == "uint8"
+    mask = tir.const((1 << nbit) - 1, "uint16")
+    f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
+    s = f4 >> tir.const(3, "uint16")
+    e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
+    # Exponent bias difference between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
+    e_bf16 = e_f4 + tir.const(126, "uint16")
+    # Scale is the exponent part, stored as a uint8
+    # To avoid overflow, the exponent is clamped to 8 bits with min
+    e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
+    m_f4 = f4 & tir.const(1, "uint16")
+    val_bf16 = tir.reinterpret("bfloat16",
+                               ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
+                                | (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
+    return val_bf16
 def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
     mask = tir.const((1 << 16) - 1, "uint32")
     res = []
...