Commit cde1886f authored by Lei Wang, committed by LeiWang1999


[Refactor] Introduce quantize components of TileLang and add testing for dequant gemm example (#494)

* Remove deprecated example_dequant_gemm.py and add DataType import in __init__.py

* lint fix

* lint fix

* Refactor dequantization examples to use tilelang imports and update data type handling in quantization utilities

* lint fix
parent 31dbb471
@@ -433,5 +433,10 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
256, 1024, 512, "float16", "float16", "float16", 3)
def main():
test_run_dequantize_gemm()
test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4()
if __name__ == "__main__":
tilelang.testing.main()
main()
@@ -266,19 +266,12 @@ def ref_program(A, qB):
return C.transpose(0, 1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
total_flops = 2 * M * N * K
def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not args.tune):
if (not tune):
program = matmul(
M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(
m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
@@ -291,10 +284,20 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
import tilelang
from tilelang import language as T
from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
_tir_packed_int_to_int_convert,)
def dequantize_gemv(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
num_bits: int = 4,
storage_dtype: str = "int8",
source_format: str = "uint",
n_partition: int = 4,
reduce_thread: int = 32,
fast_decoding: bool = False,
trans_A: bool = False,
trans_B: bool = True,
group_size: int = -1,
with_scaling: bool = False,
) -> Callable[..., Any]:
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as the related bitblas.gpu.gemv.GEMV "
"sch_outer_reduction_with_config schedule is not implemented")
assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
storage_type = "".join(c for c in storage_dtype if not c.isdigit())
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
micro_size_k_compressed = micro_size_k // num_elems_per_byte
block_K = reduce_thread * micro_size_k
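# Worked example with the defaults used in main() below (storage_dtype="int8",
# num_bits=4, in_dtype="float16", reduce_thread=32): each int8 byte packs
# 8 // 4 = 2 quantized values, a 128-bit vector load covers 128 // 16 = 8
# float16 activations (so 8 // 2 = 4 packed weight bytes per load), and each
# K-iteration consumes block_K = 32 * 8 = 256 elements.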
if group_size == -1:
group_size = K
A_shape = (M, K)
B_shape = (N, K // storage_nbit * num_bits)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
import_source: Optional[str] = None
func_name: str = ""
if fast_decoding is True:
# Lazy import to reduce startup time, as loading the intrinsics registry
# may take a while
from tilelang.quantize import get_lop3_intrin_group
lop3_intrin_info = get_lop3_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
with_scaling=with_scaling,
with_zeros=False,
)
import_source = lop3_intrin_info["c_source"]
func_name = lop3_intrin_info["func_name"]
assert import_source is not None, "c_source is not found in lop3_intrin_info"
assert func_name is not None, "func_name is not found in lop3_intrin_info"
@T.prim_func
def main(
A: T.Tensor[A_shape, in_dtype],
B: T.Tensor[B_shape, storage_dtype],
C: T.Tensor[C_shape, out_dtype],
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([micro_size_k], in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.import_source(import_source)
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
]
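# Two decode paths: the fast path calls the LOP3-based C intrinsic imported
# above to expand a whole packed vector at once, while the fallback converts
# one nbit field at a time (with sign extension) in pure TIR.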
if fast_decoding:
T.call_extern(
func_name,
T.address_of(B_quant_local[0]),
T.address_of(B_dequantize_local[0]),
dtype=in_dtype,
)
else:
for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(
storage_type,
storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte],
ki % num_elems_per_byte, in_dtype)
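# dp4a accumulates a 4-element int8 dot product into a single int32 per
# instruction; it is only used when in_dtype is int8 and accum_dtype is int32
# (see use_dp4a above), otherwise a scalar multiply-accumulate loop is used.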
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_dequantize_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
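# Reduce the per-thread partial sums across the reduce_thread lanes
# (threadIdx.x); the reduced value is then written to C by the lane with
# kr == 0.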
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return main
def main() -> None:
M = 1
N = 1024
K = 1024
in_dtype = "float16"
out_dtype = "float16"
accum_dtype = "float16"
num_bits = 4
storage_dtype = "int8"
source_format = "uint"
n_partition = 4
reduce_thread = 32
fast_decoding = True
trans_A = False
trans_B = True
group_size = -1
with_scaling = False
program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
source_format, n_partition, reduce_thread, fast_decoding, trans_A,
trans_B, group_size, with_scaling)
kernel = tilelang.compile(program)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()
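# When fast decoding is enabled, the packed weights must be pre-interleaved on
# the host so their bit layout matches what the LOP3 decode intrinsic expects
# (see interleave_weight in tilelang.quantize.utils).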
if fast_decoding:
from tilelang.quantize.utils import interleave_weight
qB = interleave_weight(qB, num_bits, in_dtype)
kernel(A, qB, C)
# int4 reference
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(A.device))
for j in range(B.shape[1]):
B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("C: ", C)
print("Ref C: ", ref_c)
# no scaling is applied, so the absolute error can be large
torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1)
if __name__ == "__main__":
main()
import tilelang.testing
import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
@tilelang.testing.requires_cuda
def test_example_dequant_gemv_fp16xint4():
example_dequant_gemv_fp16xint4.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_fp4_hopper():
example_dequant_gemm_fp4_hopper.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
past_len = seq_kv - seq_q
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
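# For the causal case the mask is applied by pre-filling acc_s with 0 where the
# position is visible and -inf where it is masked; the gemm below then
# accumulates the QK^T scores on top of this bias instead of masking afterwards.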
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of separate fadd and fmul instructions.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
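# Online softmax recurrence: with the new running max m_new, previous partial
# sums are rescaled by scores_scale = exp2((m_prev - m_new) * scale), so logsum
# accumulates sum(exp2(score * scale - m_new * scale)) over all visited KV
# blocks; acc_o is rescaled by the same factor in Rescale below.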
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
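# For causal attention only the KV blocks at or below the current query
# block's diagonal need to be visited (this bound assumes seq_q == seq_kv, as
# in the defaults); otherwise every KV block is processed.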
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
batch, heads, seq_q, seq_kv, dim, is_causal = args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not args.tune):
program = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
ref_program = partial(ref_program, is_causal=is_causal)
kernel = tilelang.compile(program, out_idx=[3])
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
@@ -58,6 +58,7 @@ from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401
import tvm
import tvm._ffi.base
from tvm import DataType # noqa: F401
from . import libinfo
from .quantization import (
_tir_packed_int_to_int_convert, # noqa: F401
_tir_packed_to_signed_convert, # noqa: F401
_tir_packed_to_unsigned_convert, # noqa: F401
_tir_packed_to_fp4_to_f16, # noqa: F401
_tir_u8_to_f8_e4m3_to_f16, # noqa: F401
_tir_packed_to_unsigned_convert_with_zeros, # noqa: F401
)
from .utils import (
gen_quant4, # noqa: F401
general_compress, # noqa: F401
interleave_weight, # noqa: F401
)
from .lop3 import get_lop3_intrin_group # noqa: F401
# Copyright 2018 The apache/tvm Authors. All Rights Reserved.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
# The code below is mostly copied from mlc.ai quantization.py in mlc-llm.
# pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""TIR computation utilities for quantization."""
from tilelang import tvm as tvm
from tvm import tir
# fmt: off
def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
mask = tir.const((1 << 16) - 1, "uint32")
res = []
for data in [v0, v1]:
u32_val = tir.reinterpret("uint32", data)
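# bf16 is the upper 16 bits of the f32 bit pattern; with round_to_even the
# bias 0x7FFF plus the lowest kept mantissa bit is added first so the
# truncation rounds to nearest-even instead of toward zero.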
if round_to_even:
rounding_bias = ((u32_val >> tir.const(16, "uint32"))
& tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32")
u32_val += rounding_bias
res.append((u32_val >> tir.const(16, "uint32")) & mask)
return res[0] | (res[1] << tir.const(16, "uint32"))
def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr):
mask = tir.const((1 << 16) - 1, "uint32")
x0 = x & mask
x1 = (x >> 16) & mask
return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1])
def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == "uint32"
mask = tvm.tir.const((1 << nbit) - 1, "uint32")
return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask)
def _tir_packed_uint_to_uint_to_float(storage_nbit: int):
storage_dtype = "uint" + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1)) - 1
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert
def _tir_packed_int_to_int_to_float(storage_nbit: int):
storage_dtype = "int" + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32")
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
return f_convert
def _tir_f32_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float32"
val_u32 = tir.reinterpret("uint32", val)
# e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7)
# e_f32 == 120 -> e_f4 = 1
# e_f32 < 120 -> e_f4 = 0
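# The 4-bit float here is 1 sign bit + 3 exponent bits with no mantissa: the
# code extracts the f32 sign, biased exponent, and top mantissa bit (m_h, used
# for rounding), then rebiases and saturates the exponent to 3 bits.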
m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32")
e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32")
s = (val_u32 >> tir.const(31, "uint32"))
e_f4 = tir.Select(
e_f32 > tir.const(120, "uint32"),
tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")),
tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"),
tir.const(0, "uint32")))
return (s << tir.const(3, "uint32")) | e_f4
def _tir_f16_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float16"
val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val))
m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32")
e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32")
s = (val_u32 >> tir.const(15, "uint32"))
e_f4 = tir.Select(
e_f16 > tir.const(8, "uint32"),
tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")),
tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32")))
return (s << tir.const(3, "uint32")) | e_f4
def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float32"
assert val.dtype == "uint32"
# e_f4 == 0 -> e_f32 = 0
# e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2
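# Decoding rebuilds the f32 bit pattern directly: OR the 3-bit exponent with
# 120 (the bias difference), place the sign bit above it, shift into the f32
# exponent field, and map e_f4 == 0 to exact zero.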
mask = tvm.tir.const((1 << nbit) - 1, "uint32")
f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask
s = f4 >> tir.const(3, "uint32")
e_f4 = f4 & tir.const(7, "uint32")
e_f32 = e_f4 | tir.const(120, "uint32")
val_f32 = tir.reinterpret("float32",
(e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32"))
return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32)
def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint32"
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
mask = tvm.tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)
def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tvm.tir.const((1 << nbit) - 1, storage_dtype)
f4 = ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(storage_dtype)
f4 = (val >> (pos.astype(storage_dtype) * tir.const(nbit, storage_dtype))) & mask
s = f4 >> tir.const(3, storage_dtype)
e_f4 = f4 & tir.const(7, storage_dtype)
e_f16 = e_f4 | tir.const(8, storage_dtype)
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype("uint16"))
return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, "float16"), val_f16)
return f_convert
def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"),
tir.const(0x4000, "uint16"))
e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix
return tir.reinterpret("float16", s_f16 | e_f16)
def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
return tir.reinterpret("float16", s_f16 | e_f16)
def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
return tir.reinterpret("e5m2_float8", val).astype("float16")
def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1))
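# The stored field is unsigned with a zero point of 2^(nbit-1); subtracting
# max_int_value maps [0, 2^nbit - 1] back to [-2^(nbit-1), 2^(nbit-1) - 1].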
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert
def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tvm.tir.const((1 << nbit) - 1, storage_dtype)
return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype)
return f_convert
def _tir_packed_to_unsigned_convert_with_zeros(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm.tir.PrimExpr,
dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tvm.tir.const((1 << nbit) - 1, storage_dtype)
return (((val >> (pos * nbit).astype(storage_dtype)) & mask) - zero).astype(dtype)
return f_convert
def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32")
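# Sign-extend the extracted nbit field: shifting left places its sign bit at
# bit 31, and the arithmetic right shift then propagates it back down.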
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
return f_convert
# fmt: on
def gen_quant4(k, n, groupsize=-1):
import torch
import torch.nn as nn
maxq = 2**4
w = torch.randn((k, n), dtype=torch.half, device="cpu")
original_w = w.clone()
if groupsize == -1:
groupsize = k
if groupsize != -1:
w = w.reshape((-1, groupsize, n))
w = w.permute(1, 0, 2)
w = w.reshape((groupsize, -1))
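# Symmetric per-group scale: s = 2 * max|w| / maxq, so round(w / s) lands
# roughly in [-maxq/2, maxq/2]; adding maxq // 2 below shifts it to unsigned
# storage.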
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / maxq
# Quantize.
w = torch.round(w / s).int()
# Unsigned storage.
w += (maxq) // 2
w = torch.clamp(w, 0, maxq)
# Dequantize.
ref = (w - (maxq) // 2).half() * s
if groupsize != -1:
def reshape(w):
w = w.reshape((groupsize, -1, n))
w = w.permute(1, 0, 2)
w = w.reshape((k, n)).contiguous()
return w
ref = reshape(ref)
w = reshape(w)
s = s.reshape((-1, n)).contiguous()
linear = nn.Linear(k, n, bias=False)
linear.weight.data = ref.t()
return original_w, linear, s, (w - (maxq) // 2)
def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None):
import torch
if storage_dtype is None:
storage_dtype = torch.int8
elems_per_byte = 8 // source_bits
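# Pack elems_per_byte low-bit values into each int8: for source_bits=4 the
# element at position j*2 fills bits 0-3 of byte j and the element at j*2+1
# fills bits 4-7.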
if lowprecision_weight.dtype == torch.float16:
lowprecision_weight = lowprecision_weight.to(torch.int8)
int8_weight = torch.zeros(
(*lowprecision_weight.shape[:-1], lowprecision_weight.shape[-1] // elems_per_byte),
dtype=torch.int8,
device=lowprecision_weight.device)
for j in range(lowprecision_weight.shape[-1] // elems_per_byte):
for k in range(elems_per_byte):
int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] <<
(source_bits * k)).to(torch.int8)
return int8_weight.to(storage_dtype)
# interleave weight implementation (operates on torch tensors, ported from the numpy version)
def interleave_weight(qweight, nbits=4, target_dtype="float16"):
"""Interleave the weight to the target data type.
Args:
qweight (_type_): _description_
nbits (int, optional): _description_. Defaults to 4.
target_dtype (str, optional): _description_. Defaults to "float16".
Returns:
_type_: _description_
Example:
qweight = torch.randint(0, 127, (10, 10), dtype=torch.int8).cuda()
interleave_weight(qweight, 4, "float16")
"""
import torch
assert target_dtype in ["float16", "int8"]
# reinterpret the data type of qweight to int32
qweight = qweight.view(torch.int32)
new_qweight = torch.zeros_like(qweight)
bits_stride = 8 if target_dtype == "int8" else 16
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
num_groups = 32 // bits_stride
elems_per_group = bits_stride // nbits
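# Worked example for nbits=4 and target_dtype="float16": bits_stride=16, so
# num_groups=2 and elems_per_group=4; the shift formula regroups the eight
# nibbles of each 32-bit word so source nibbles 0,2,4,6 land in the low 16
# bits and 1,3,5,7 in the high 16 bits, the layout assumed by the fast-decode
# intrinsics.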
for i in range(num_groups):
for j in range(elems_per_group):
offset = i * elems_per_group + j
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
if nbits == 1 and target_dtype == "int8":
# special handling for 1b interleave
n16_weight = new_qweight & torch.int32(0xF0F00F0F)
n16_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 16
n16_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24
n16_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4
n16_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 12
return n16_weight.view(torch.int8)
elif nbits == 2 and target_dtype == "float16":
n8_weight = new_qweight & torch.int32(0xFF0000FF)
n8_weight |= ((new_qweight & torch.int32(0x0000FF00)) >> 8) << 16
n8_weight |= ((new_qweight & torch.int32(0x00FF0000)) >> 16) << 8
return n8_weight.view(torch.int8)
elif nbits == 1 and target_dtype == "float16":
n8_weight = new_qweight & torch.int32(0xF000000F)
n8_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 8
n8_weight |= ((new_qweight & torch.int32(0x00000F00)) >> 8) << 16
n8_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24
n8_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4
n8_weight |= ((new_qweight & torch.int32(0x00F00000)) >> 20) << 12
n8_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 20
return n8_weight.view(torch.int8)
return new_qweight.view(torch.int8)