Unverified Commit 4bcb1593 authored by Cunxiao Ni, committed by GitHub

[Example] add w4a8 gemm kernel (#815)

* [Bugfix] fix autotune bug

* [Example] add w4a8 gemm kernel

* fix lint: pin the version of `ml_dtypes`
  Pin `ml_dtypes` in the dependency specification; if the installed version is too old, errors such as fp4 not being defined can occur.

* Renames example for dequantization GEMM

* format

* add w4a8 example to ci

* fix lint
parent 85d1a6b3
@@ -110,11 +110,11 @@ def test_fp4_fp16_convert_close():

 def get_configs():
-    block_M = [128]
-    block_N = [128, 256]
-    block_K = [128]
-    num_stages = [2]
-    threads = [256]
+    block_M = [64, 128]
+    block_N = [64, 128]
+    block_K = [128, 256]
+    num_stages = [1, 2]
+    threads = [128, 256]
     splits = [1]
     _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
@@ -239,11 +239,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
     if tune:

-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @tilelang.jit(out_idx=[2])
         def kernel(block_M=None,
                    block_N=None,
@@ -251,7 +247,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
                    num_stages=None,
                    threads=None,
                    split=None):
-            return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
+            return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func
         return kernel()
     else:
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse


def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "int8"
    assert val.dtype == "uint8"
    mask = tir.const((1 << nbit) - 1, "uint8")
    i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask
    i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
    i8 = i8_shifted >> tir.const(4, "int8")
    return i8
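

# Host-side reference for the unpacking above (an illustrative sketch, not part of
# the original example): extract the 4-bit field at `pos` and sign-extend it the
# same way the TIR expression does.
def _u8_to_i4_reference(byte: int, pos: int) -> int:
    nibble = (byte >> (4 * pos)) & 0xF             # pos 0 = low nibble, pos 1 = high nibble
    return nibble - 16 if nibble >= 8 else nibble  # sign-extend the 4-bit value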


def get_configs():
    iter_params = dict(
        block_M=[64, 128],
        block_N=[64, 128],
        block_K=[128, 256],
        num_stages=[1, 2],
        threads=[128, 256, 512],
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
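

# The Cartesian product above yields 2 * 2 * 2 * 2 * 3 = 48 candidate configurations
# for @autotune to benchmark when tuning is enabled.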


@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"
    B_shape = (N, K // num_elems_per_byte)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    @T.prim_func
    def main(
            B: T.Tensor(B_shape, storage_dtype),
            C: T.Tensor((N, K), in_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
            for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                        num_bits,
                        B_local[i, j // num_elems_per_byte],
                        j % num_elems_per_byte,
                        dtype=in_dtype,
                    )
                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

    return main
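

# How the dequantization-only kernel could be checked against the torch_convert
# reference defined below (an illustrative sketch; assumes a CUDA device and is
# not executed by this file):
#   qB = torch.randint(0, 256, (256, 128), dtype=torch.uint8, device="cuda")
#   convert_kernel = _convert_test(256, 256, block_N=64, block_K=64, in_dtype="int8")
#   torch.testing.assert_close(convert_kernel(qB), torch_convert(qB))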


def torch_convert(tensor):

    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        i4_shifted = ((val >> (pos * 4)) & mask)
        i4 = ((i4_shifted << 4) >> 4)
        return i4.view(torch.int8)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor
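

# A vectorized equivalent of torch_convert (an illustrative sketch, not used by the
# example): unpack both nibbles of every byte at once and sign-extend them.
def torch_convert_vectorized(tensor):
    low = (tensor & 0xF).to(torch.int16)    # pos 0: low nibble
    high = (tensor >> 4).to(torch.int16)    # pos 1: high nibble
    nibbles = torch.stack((low, high), dim=-1).reshape(tensor.shape[0], -1)
    return torch.where(nibbles >= 8, nibbles - 16, nibbles).to(torch.int8)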


def ref_program(A, qB):
    dtypeC = "int32"
    B = torch_convert(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(getattr(torch, dtypeC))
    return C.transpose(0, 1)
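

# Note on layout: the kernel below writes its result as Ct with shape (N, M) (the
# transposed product), which is why ref_program returns C.transpose(0, 1).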


def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):

    @tilelang.jit(out_idx=[2])
    def kernel_func(block_M, block_N, block_K, num_stages, threads):
        num_elems_per_byte = 8 // num_bits
        storage_dtype = "uint8"
        A_shape = (M, K)
        B_shape = (N, K // num_elems_per_byte)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K // num_elems_per_byte)
        B_dequantize_local_shape = (block_N, block_K)
        assert K % (block_K) == 0

        @T.prim_func
        def main(
                A: T.Tensor(A_shape, in_dtype),
                B: T.Tensor(B_shape, storage_dtype),
                Ct: T.Tensor((N, M), out_dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })

                T.clear(Ct_local)
                for k in T.Pipelined(K // block_K, num_stages=num_stages):
                    # Load the activation tile and the packed int4 weight tile into shared memory.
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    # Unpack two int4 values per byte and sign-extend them to int8.
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    # Accumulate Ct_local += B_dequantized @ A_shared^T (transposed product).
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, Ct_shared)
                T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
                                     by * block_M:(by + 1) * block_M])

        return main

    if tune:

        @autotune(configs=get_configs(), warmup=10, rep=10)
        @tilelang.jit(out_idx=[2])
        def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func

        return kernel()
    else:

        def kernel(block_M, block_N, block_K, num_stages, threads):
            return kernel_func(block_M, block_N, block_K, num_stages, threads)

        return kernel


def main(m=128, n=256, k=256, tune=False):
    total_flops = 2 * m * n * k
    if not tune:
        kernel = matmul_int8xint4(
            m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
                block_M=32, block_N=32, block_K=128, num_stages=1, threads=128)
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
        print("All checks pass.")
        latency = profiler.do_bench(warmup=50)
        print(f"Tilelang: {latency} ms")
    else:
        best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
        best_latency = best_result.latency
        best_config = best_result.config
        print(f"Best latency: {best_latency}")
        print(f"Best config: {best_config}")
        # latency is in ms, so TFLOPS = total_flops / (latency * 1e-3) / 1e12 = total_flops / latency * 1e-9
        print(f"Best tflops: {total_flops / best_latency * 1e-9}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=512, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=512, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=512, help="Matrix dimension K")
    parser.add_argument("--tune", action="store_true", help="Enable tuning")
    args = parser.parse_args()
    M, N, K = args.m, args.n, args.k
    main(M, N, K, args.tune)
    # main(M, N, K, True)
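
Assuming the new file is saved as example_dequant_gemm_w4a8.py (the module name imported by the test change below), it can be run directly, e.g. `python example_dequant_gemm_w4a8.py --m 512 --n 512 --k 512`, or with `--tune` added to sweep the configurations from get_configs() and report the best latency.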
@@ -4,6 +4,7 @@ import example_dequant_gemv_fp16xint4
 import example_dequant_gemm_fp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper_tma
+import example_dequant_gemm_w4a8

 @tilelang.testing.requires_cuda
@@ -29,5 +30,10 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
     example_dequant_gemm_bf16_mxfp4_hopper_tma.main()


+@tilelang.testing.requires_cuda
+def test_example_dequant_gemm_w4a8():
+    example_dequant_gemm_w4a8.main()
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
@@ -4,6 +4,6 @@ numpy>=1.23.5
 tqdm>=4.62.3
 typing_extensions>=4.10.0
 cloudpickle
-ml_dtypes
+ml_dtypes>=0.5.3
 psutil
 torch