Commit cde1886f authored by Lei Wang, committed by LeiWang1999

[Refactor] Introduce quantize components of TileLang and add testing for dequant gemm example (#494)

* Remove deprecated example_dequant_gemm.py and add DataType import in __init__.py

* lint fix

* lint fix

* Refactor dequantization examples to use tilelang imports and update data type handling in quantization utilities

* lint fix
parent 31dbb471
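
For context, this commit re-exports tvm.DataType at the top level of tilelang and groups the packed-integer converters and LOP3 intrinsics under tilelang.quantize, as the diffs below show. A minimal sketch of the new import surface (illustrative only; it exercises just the names that appear in this commit):

import tilelang
from tilelang import DataType
from tilelang.quantize import _tir_packed_int_to_int_convert, get_lop3_intrin_group

# DataType carries the bit-width metadata the examples use for vectorization math,
# e.g. one 128-bit transaction holds 128 // DataType("float16").bits == 8 fp16 values.
assert 128 // DataType("float16").bits == 8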
...@@ -433,5 +433,10 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
        256, 1024, 512, "float16", "float16", "float16", 3)


def main():
    test_run_dequantize_gemm()
    test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4()


if __name__ == "__main__":
    main()
...@@ -266,19 +266,12 @@ def ref_program(A, qB):
    return C.transpose(0, 1)


def main(m=256, n=256, k=256, tune=False):
    total_flops = 2 * m * n * k

    if (not tune):
        program = matmul(
            m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
                block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
        kernel = tilelang.compile(program, out_idx=[2])
        profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
...@@ -291,10 +284,20 @@ if __name__ == "__main__":
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
        best_latency = best_result.latency
        best_config = best_result.config
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--m', type=int, default=256, help='M')
    parser.add_argument('--n', type=int, default=256, help='N')
    parser.add_argument('--k', type=int, default=256, help='K')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    M, N, K = args.m, args.n, args.k
    main(M, N, K, args.tune)
import tilelang
from tilelang import language as T
from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
_tir_packed_int_to_int_convert,)
def dequantize_gemv(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
num_bits: int = 4,
storage_dtype: str = "int8",
source_format: str = "uint",
n_partition: int = 4,
reduce_thread: int = 32,
fast_decoding: bool = False,
trans_A: bool = False,
trans_B: bool = True,
group_size: int = -1,
with_scaling: bool = False,
) -> Callable[..., Any]:
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as the related bitblas.gpu.gemv.GEMV."
"sch_outer_reduction_with_config is not implemented")
assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
storage_type = "".join(c for c in storage_dtype if not c.isdigit())
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
micro_size_k_compressed = micro_size_k // num_elems_per_byte
block_K = reduce_thread * micro_size_k
if group_size == -1:
group_size = K
A_shape = (M, K)
B_shape = (N, K // storage_nbit * num_bits)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
import_source: Optional[str] = None
func_name: str = ""
if fast_decoding is True:
# Lazy import to decrease the startup time
# as intrin registry may take a while to load
from tilelang.quantize import get_lop3_intrin_group
lop3_intrin_info = get_lop3_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
with_scaling=with_scaling,
with_zeros=False,
)
import_source = lop3_intrin_info["c_source"]
func_name = lop3_intrin_info["func_name"]
assert import_source is not None, "lop3_intrin_info is not found"
assert func_name is not None, "lop3_intrin_info is not found"
import_source = import_source
@T.prim_func
def main(
A: T.Tensor[A_shape, in_dtype],
B: T.Tensor[B_shape, storage_dtype],
C: T.Tensor[C_shape, out_dtype],
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([micro_size_k], in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.import_source(import_source)
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
]
if fast_decoding:
T.call_extern(
func_name,
T.address_of(B_quant_local[0]),
T.address_of(B_dequantize_local[0]),
dtype=in_dtype,
)
else:
for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(
storage_type,
storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte],
ki % num_elems_per_byte, in_dtype)
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_dequantize_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return main
def main() -> None:
M = 1
N = 1024
K = 1024
in_dtype = "float16"
out_dtype = "float16"
accum_dtype = "float16"
num_bits = 4
storage_dtype = "int8"
source_format = "uint"
n_partition = 4
reduce_thread = 32
fast_decoding = True
trans_A = False
trans_B = True
group_size = -1
with_scaling = False
program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
source_format, n_partition, reduce_thread, fast_decoding, trans_A,
trans_B, group_size, with_scaling)
kernel = tilelang.compile(program)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()
if fast_decoding:
from tilelang.quantize.utils import interleave_weight
qB = interleave_weight(qB, num_bits, in_dtype)
kernel(A, qB, C)
# int4 reference
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for j in range(B.shape[1]):
B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("C: ", C)
print("Ref C: ", ref_c)
# scaling is not applied here, so the absolute error is large
torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1)
if __name__ == "__main__":
main()
import tilelang.testing
import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
@tilelang.testing.requires_cuda
def test_example_dequant_gemv_fp16xint4():
example_dequant_gemv_fp16xint4.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_fp4_hopper():
example_dequant_gemm_fp4_hopper.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
past_len = seq_kv - seq_q
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
batch, heads, seq_q, seq_kv, dim, is_causal = args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not args.tune):
program = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
ref_program = partial(ref_program, is_causal=is_causal)
kernel = tilelang.compile(program, out_idx=[3])
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
...@@ -58,6 +58,7 @@ from .env import enable_cache, disable_cache, is_cache_enabled  # noqa: F401
import tvm
import tvm._ffi.base
from tvm import DataType  # noqa: F401
from . import libinfo
...
from .quantization import (
_tir_packed_int_to_int_convert, # noqa: F401
_tir_packed_to_signed_convert, # noqa: F401
_tir_packed_to_unsigned_convert, # noqa: F401
_tir_packed_to_fp4_to_f16, # noqa: F401
_tir_u8_to_f8_e4m3_to_f16, # noqa: F401
_tir_packed_to_unsigned_convert_with_zeros, # noqa: F401
)
from .utils import (
gen_quant4, # noqa: F401
general_compress, # noqa: F401
interleave_weight, # noqa: F401
)
from .lop3 import get_lop3_intrin_group # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, Literal
decode_i4_to_f16 = """
template <typename T1, typename T2, bool isSigned = false>
__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
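// How the decode works: the lop3 below keeps one 4-bit nibble in each fp16 half
// (mask 0x000f) and ORs in 0x6400 (fp16 1024.0), so each half becomes 1024 + v;
// subtracting MEDIAN_NUM (1024.0, or 1032.0 = 0x6408 for the signed, offset-by-8
// encoding) then yields the decoded value with no explicit int-to-float conversion.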
uint const i4s = *reinterpret_cast<uint *>(_i4s);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
}
}
template <typename T1, typename T2>
__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)
{
decode_i4b_to_f16<T1, T2, true>(_i4s, B_local_decode, N);
}
template <typename T1, typename T2>
__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8)
{
decode_i4b_to_f16<T1, T2, false>(_i4u, B_local_decode, N);
}
"""
decode_i4_to_f16_scale = """
template <typename T1, typename T2, typename T3, bool isSigned = false, bool withScaling = false>
__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
decode_i4b_to_f16_scale<T1, T2, T3, true, true>(_i4s, B_local_decode, N, scale);
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
decode_i4b_to_f16_scale<T1, T2, T3, false, true>(_i4u, B_local_decode, N, scale);
}
"""
decode_i4_to_f16_scale_offset = """
template <typename T1, typename T2, typename T3, bool isSigned = false, bool withScaling = false>
__device__ void decode_i4b_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const int offset = 0)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_l = *scale;
T3 const scale_r = *(scale + offset);
uint const packed_scales_l = __pack_half2(scale_l, scale_l);
uint const packed_scales_r = __pack_half2(scale_r, scale_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
}
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
}
#pragma unroll
for (int i = (N / 4); i < (N / 2); i++)
{
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i4s_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8)
{
decode_i4b_to_f16_scale_offset<T1, T2, T3, true, true>(_i4s, B_local_decode, N, scale, offset);
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i4u_to_f16_scale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8)
{
decode_i4b_to_f16_scale_offset<T1, T2, T3, false, true>(_i4u, B_local_decode, N, scale, offset);
}
"""
decode_i4_to_f16_scale_zeros_original = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
// input zeros may be in int32 (qzeros) or half format
T4 const zero_r = *zeros;
uint const packed_zeros = __pack_half2(zero_r, zero_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
decode_i4b_to_f16_zeros_original<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros);
}
"""
decode_i4_to_f16_scale_zeros_original_offset = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_zeros_original_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_l = *scale;
T3 const scale_r = *(scale + offset);
uint const packed_scales_l = __pack_half2(scale_l, scale_l);
uint const packed_scales_r = __pack_half2(scale_r, scale_r);
// input zeros may be in int32 (qzeros) or half format
T3 const zeros_l = *zeros;
T3 const zeros_r = *(zeros + offset);
uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l);
uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
}
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
}
#pragma unroll
for (int i = (N / 4); i < (N / 2); i++)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
}
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_original_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8)
{
decode_i4b_to_f16_zeros_original_offset<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros, offset);
}
"""
decode_i4_to_f16_scale_zeros_rescale = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
T4 const zero_r = *zeros;
uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros));
}
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
decode_i4b_to_f16_scale_zeros_rescale<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros);
}
"""
decode_i4_to_f16_scale_zeros_rescale_offset = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale_zeros_rescale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_l = *scale;
T3 const scale_r = *(scale + offset);
uint const packed_scales_l = __pack_half2(scale_l, scale_l);
uint const packed_scales_r = __pack_half2(scale_r, scale_r);
// input zeros may be in int32 (qzeros) or half format
T3 const zeros_l = *zeros;
T3 const zeros_r = *(zeros + offset);
uint const packed_zeros_l = 0x80008000 | __pack_half2(zeros_l, zeros_l);
uint const packed_zeros_r = 0x80008000 | __pack_half2(zeros_r, zeros_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
}
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(packed_zeros_l));
}
#pragma unroll
for (int i = (N / 4); i < (N / 2); i++)
{
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(packed_zeros_r));
}
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_rescale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8)
{
decode_i4b_to_f16_scale_zeros_rescale_offset<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros, offset);
}
"""
decode_i4_to_f16_scale_zeros_quantized = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
// input zeros may be in int32 (qzeros) or half format
int16_t const zero_r = *((int16_t*)zeros);
uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
template <typename storage_dtype, typename target_dtype, typename scale_dtype, typename zero_dtype>
__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8)
{
decode_i4b_to_f16_scale_zeros_quantized<storage_dtype, target_dtype, scale_dtype, zero_dtype, false>(_i4u, B_local_decode, N, scale, zeros);
}
"""
decode_i4_to_f16_scale_zeros_quantized_offset = """
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale_zeros_quantized_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T1 *qzeros = nullptr, const int scale_offset = 0, const int qzeros_offset = 0, const int group_offset = 0)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_l = *scale;
T3 const scale_r = *(scale + scale_offset);
uint const packed_scales_l = __pack_half2(scale_l, scale_l);
uint const packed_scales_r = __pack_half2(scale_r, scale_r);
const int num_elems_per_storage_dtype = sizeof(T1) * 8 / 4;
T1 const qzeros_l = *qzeros;
T1 const qzeros_r = *(qzeros + qzeros_offset);
int16_t const zero_l = (qzeros_l >> (group_offset * 4) & 0xf);
int16_t const zero_r = (qzeros_r >> (group_offset * 4) & 0xf);
uint median_num_l = ((0xe400 | zero_l) << 16) | (0xe400 | zero_l);
uint median_num_r = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
}
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_l));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
}
#pragma unroll
for (int i = (N / 4); i < (N / 2); i++)
{
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_r));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
}
}
template <typename storage_dtype, typename target_dtype, typename scale_dtype>
__device__ void decode_i4u_to_f16_scale_zeros_quantized_offset(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, storage_dtype *qzeros = nullptr, const int scale_offset = 0, const int zero_offset = 0, const int group_offset = 0, const int N = 8)
{
decode_i4b_to_f16_scale_zeros_quantized_offset<storage_dtype, target_dtype, scale_dtype, false>(_i4u, B_local_decode, N, scale, qzeros, scale_offset, zero_offset, group_offset);
}
"""
decode_i2_to_f16 = """
template <typename T1, typename T2, bool isSigned = false>
__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
// otherwise the pointer of _i2s should be moved to
int i2s = (i2s_i16 & 0x00ff);
i2s |= ((i2s_i16 & 0xff00) << 8);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
}
}
template <typename T1, typename T2>
__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8)
{
decode_i2b_to_f16<T1, T2, true>(_i2s, B_local_decode, N);
}
template <typename T1, typename T2>
__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8)
{
decode_i2b_to_f16<T1, T2, false>(_i2u, B_local_decode, N);
}
"""
decode_i2_to_f16_scale = """
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
// otherwise the pointer of _i2s should be moved to
int i2s = (i2s_i16 & 0x00ff);
i2s |= ((i2s_i16 & 0xff00) << 8);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8)
{
decode_i2b_to_f16_scale<T1, T2, T3, true>(_i2s, B_local_decode, scale, N);
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8)
{
decode_i2b_to_f16_scale<T1, T2, T3, false>(_i2u, B_local_decode, scale, N);
}
"""
decode_i2_to_f16_scale_zeros_original_offset = """
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_original_offset(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int offset = 0, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
// otherwise the pointer of _i2s should be moved to
int i2s = (i2s_i16 & 0x00ff);
i2s |= ((i2s_i16 & 0xff00) << 8);
T3 const zeros_l = *zeros;
T3 const zeros_r = *(zeros + offset);
uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l);
uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r);
T3 const scale_l = *scale;
T3 const scale_r = *(scale + offset);
uint const packed_scales_l = __pack_half2(scale_l, scale_l);
uint const packed_scales_r = __pack_half2(scale_r, scale_r);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
}
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
}
#pragma unroll
for (int i = (N / 4); i < (N / 2); i++)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale_zeros_original_offset(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int offset = 0, const int N = 8)
{
decode_i2b_to_f16_scale_zeros_original_offset<T1, T2, T3, false>(_i2u, B_local_decode, scale, zeros, offset, N);
}
"""
decode_i2_to_f16_scale_zeros_original = """
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_original(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
// otherwise the pointer of _i2s should be moved to
int i2s = (i2s_i16 & 0x00ff);
i2s |= ((i2s_i16 & 0xff00) << 8);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale_zeros_original(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8)
{
decode_i2b_to_f16_scale_zeros_original<T1, T2, T3, false>(_i2u, B_local_decode, scale, zeros, N);
}
"""
decode_i2_to_f16_scale_zeros_rescale = """
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_rescale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
// otherwise the pointer of _i2s should be moved to
int i2s = (i2s_i16 & 0x00ff);
i2s |= ((i2s_i16 & 0xff00) << 8);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8)
{
decode_i2b_to_f16_scale_zeros_rescale<T1, T2, T3, false>(_i2u, B_local_decode, scale, zeros, N);
}
"""
decode_i2_to_f16_scale_zeros_quantized = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
int16_t const zero_r = *((int16_t*)zeros);
uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
// otherwise the pointer of _i2s should be moved to
int i2s = (i2s_i16 & 0x00ff);
i2s |= ((i2s_i16 & 0xff00) << 8);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
decode_i2b_to_f16_scale_zeros_quantized<T1, T2, T3, T4, false>(_i2u, B_local_decode, N, scale, zeros);
}
"""
decode_i1_to_f16 = """
/*
Kind 0: original
Kind 1: rescale
Kind 2: quantized
# documents for zeros_mode:
# original: target = (dequantize_weight - zero_point) * scale
# rescale: target = dequantize_weight * scale - zero_point
# quantized: target = (dequantize_weight - dequantize_zeros) * scale
# Notice: only support "original" and "rescale" now
zeros_mode: Literal["original", "rescale", "quantized"] = "original"
*/
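// Worked example (values chosen purely for illustration): with a decoded weight
// w = 1.0, scale = 0.5 and zero_point = 0.25,
//   original: (1.0 - 0.25) * 0.5 = 0.375
//   rescale:   1.0 * 0.5 - 0.25 = 0.25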
template <typename T1, typename T2, bool isSigned = false, bool withScaling = false, bool withZeros = false, int ZerosKind = 1>
__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
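// 0xbc00 is fp16 -1.0 packed into both halves; the signed path below first doubles
// h[i] (add.f16x2 with itself) and then adds -1, mapping the unpacked bit {0, 1} to {-1, +1}.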
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
if constexpr (isSigned)
{
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
}
if constexpr (withZeros && ZerosKind == 0)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
}
if constexpr (withScaling)
{
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0));
}
if constexpr (withZeros && ZerosKind == 1)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
}
}
}
template <typename T1, typename T2>
__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
{
decode_i1b_to_f16<T1, T2, true>(_i1s, B_local_decode, N);
}
template <typename T1, typename T2>
__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8)
{
decode_i1b_to_f16<T1, T2, false>(_i1u, B_local_decode, N);
}
"""
decode_i1_to_f16_scale = """
template <typename T1, typename T2, typename T3>
__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
"""
decode_i1_to_f16_scale_zeros_original = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
// input zeros may be in int32 (qzeros) or half format
T4 const zero_r = *zeros;
uint const packed_zeros = __pack_half2(zero_r, zero_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
decode_i1b_to_f16_zeros_original<T1, T2, T3, T4, false>(_i1u, B_local_decode, N, scale, zeros);
}
"""
decode_i1_to_f16_scale_zeros_rescale = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
T4 const zero_r = *zeros;
uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros));
}
}
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
decode_i1b_to_f16_scale_zeros_rescale<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros);
}
"""
decode_i1s_to_i8s = """template <typename T1, typename T2>
__device__ void decode_i1s_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)
{
int i8s[4];
// vector load
*reinterpret_cast<int4 *>(i8s) = *reinterpret_cast<int4 *>(_i8s);
int16_t i1b_i16 = *reinterpret_cast<int16_t *>(_i1b);
// permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15}
// into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x}
int i1b = (i1b_i16 & 0x0f0f);
i1b |= ((i1b_i16 & 0xf0f0) << 12);
// i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0}
// interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// First, we extract the i1b and construct an intermediate fp16 number.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1
static constexpr uint I8s_MAGIC_NUM = 0x00000000;
static constexpr uint TRANSFORM_SUBTRACT = 0xffffffff; // for signed int 2x - 1
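// 0xff is -1 in each signed int8 lane: __vadd4(x, x) doubles every lane and adding
// 0xffffffff then subtracts 1, mapping each unpacked bit {0, 1} to {-1, +1}.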
for (int i = 0; i < N / 4; i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vadd4(i8s[i], i8s[i]);
i8s[i] = __vadd4(i8s[i], TRANSFORM_SUBTRACT);
}
*reinterpret_cast<int4 *>(_i8s) = *reinterpret_cast<int4 *>(i8s);
}
template <typename T1, typename T2>
__device__ void decode_i1u_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)
{
int *i8s = reinterpret_cast<int *>(_i8s);
int16_t i1b_i16 = *reinterpret_cast<int16_t *>(_i1b);
// permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15}
// into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x}
int i1b = (i1b_i16 & 0x0f0f);
i1b |= ((i1b_i16 & 0xf0f0) << 12);
// i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0}
// interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
// First, we extract the i1b and construct an intermediate fp16 number.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1
static constexpr uint I8s_MAGIC_NUM = 0x00000000;
static constexpr uint MEDIAN_NUM = 0x00000000;
for (int i = 0; i < N / 4; i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
}
}
"""
decode_i2s_to_i8s = """template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
// convert 8 int2b_t to 8 int8b_t -> 2 int32
uint *i8s = reinterpret_cast<uint *>(_i8s);
// i2b = {e7,e6,e5,e4,e3,e2,e1,e0}
// also require interleave {e7,e3,e6,e2,e5,e1,e4,e0}
uint const i2b = *reinterpret_cast<uint *>(_i2b);
// First, we extract the 2-bit values and construct an intermediate i8 number.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x03030303; // 0x3 -> 0b11, the low two bits of each byte
static constexpr uint I8s_MAGIC_NUM = 0x00000000;
static constexpr uint MEDIAN_NUM = 0x02020202;
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vsub4(i8s[i], MEDIAN_NUM);
}
}
template <typename T1, typename T2>
__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
// convert 8 int2b_t to 8 int8b_t -> 2 int32
uint *i8s = reinterpret_cast<uint *>(_i8s);
// i2b = {e7,e6,e5,e4,e3,e2,e1,e0}
// also require interleave {e7,e3,e6,e2,e5,e1,e4,e0}
uint const i2b = *reinterpret_cast<uint *>(_i2b);
// First, we extract the 2-bit values and construct an intermediate i8 number.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x03030303; // 0x3 -> 0b11, the low two bits of each byte
static constexpr uint I8s_MAGIC_NUM = 0x00000000;
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
}
}
"""
decode_i4s_to_i8s = """template <typename T1, typename T2>
__device__ void decode_i4s_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16)
{
uint *i8s = reinterpret_cast<uint *>(_i8s);
uint *i4b = reinterpret_cast<uint *>(_i4b);
// First, we extract the i4s and construct an intermediate i8 number.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12
static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0
static constexpr uint MEDIAN_NUM = 0x07070707;
#pragma unroll
for (int i = 0; i < (N / 8); i++)
{
// extract four 4-bit elements into the low nibble of each byte of the output register
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i + 2])
: "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM);
i8s[i + 2] = __vsubss4(i8s[i + 2], MEDIAN_NUM);
}
}
template <typename T1, typename T2>
__device__ void decode_i4u_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16)
{
uint *i8s = reinterpret_cast<uint *>(_i8s);
uint *i4b = reinterpret_cast<uint *>(_i4b);
// First, we extract the i4s and construct an intermediate i8 number.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12
static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0
#pragma unroll
for (int i = 0; i < (N / 8); i++)
{
// extract four 4-bit elements into the low nibble of each byte of the output register
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i + 2])
: "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
}
}
"""
decode_i2s_to_i4s = r"""
template <typename T1, typename T2, bool isSigned>
__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16)
{
uint *i4s = reinterpret_cast<uint *>(_i4s);
uint *i2b = reinterpret_cast<uint *>(_i2b);
// First, we extract the 2-bit values and construct an intermediate i4 number.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x33333333; // 0x3 -> 0b11, the low two bits of each 4-bit lane
static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0
static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000;
#pragma unroll
for (int i = 0; i < (N / 8); i++)
{
// extract the 2-bit elements into the low two bits of each 4-bit lane
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(i4s[i])
: "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
if constexpr (isSigned)
{
// TODO(lei): uint4 sub should be enhanced.
// 0x03 0x03 0x03 0x03
// i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i];
}
}
}
template <typename T1, typename T2>
__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16)
{
decode_i2b_to_i4s<T1, T2, true>(_i4s, B_local_decode, N);
}
template <typename T1, typename T2>
__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16)
{
decode_i2b_to_i4s<T1, T2, false>(_i4u, B_local_decode, N);
}
"""
def get_lop3_intrin_group(
out_dtype: Literal["float16", "int8", "int4"],
source_format: Literal["int", "uint"] = "uint",
source_bit: int = 4,
storage_dtype: Literal["int32", "int8"] = "int8",
with_scaling: bool = False,
with_zeros: bool = False,
zeros_mode: Literal["original", "rescale", "quantized"] = "original",
storage_scope: str = "local",
) -> Dict[str, str]:
"""
This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding.
LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
intrinsic operations that can be performed on these inputs. This function retrieves and returns this group.
Parameters
----------
in_dtype : Literal["int8"]
The data type of the input. It should be "int8".
out_dtype : Literal["float16", "int8", "int4"]
The data type of the output. It can be either "float16" or "int8" or "int4".
storage_nbit : int, optional
The number of bits used for storage. By default, it is 4.
with_scale : bool, optional
A boolean parameter that indicates whether scaling should be applied. By default, it is False.
with_zeros : bool, optional
A boolean parameter that indicates whether zeros should be used. By default, it is False.
zeros_mode : Literal["original", "rescale", "quantized"], optional
The mode of zeros. It can be either "original", "rescale", or "quantized". By default, it is "original".
storage_scope : Literal["local", "warp"], optional
The scope of the storage. It can be either "local" or "warp". By default, it is "local".
Returns
-------
Dict[str, str]
A dictionary mapping the names of the intrinsics to their corresponding implementations.
"""
assert out_dtype in [
"float16", "int8", "int4"
], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' .")
dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"}
target_dtype = dtype_mapping[out_dtype]
if source_format not in ["int", "uint"]:
raise ValueError(
f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.")
if with_zeros and source_format == "int":
raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}")
source_symbol = "i" if source_format == "int" else "u"
import_c_map = {
"i4_to_f16": decode_i4_to_f16,
"i2_to_f16": decode_i2_to_f16,
"i1_to_f16": decode_i1_to_f16,
"i4_to_f16_scale": decode_i4_to_f16_scale,
"i4_to_f16_scale_offset": decode_i4_to_f16_scale_offset,
"i2_to_f16_scale": decode_i2_to_f16_scale,
"i1_to_f16_scale": decode_i1_to_f16_scale,
"i4_to_f16_scale_zeros_original": decode_i4_to_f16_scale_zeros_original,
"i4_to_f16_scale_zeros_original_offset": decode_i4_to_f16_scale_zeros_original_offset,
"i2_to_f16_scale_zeros_original": decode_i2_to_f16_scale_zeros_original,
"i1_to_f16_scale_zeros_original": decode_i1_to_f16_scale_zeros_original,
"i4_to_f16_scale_zeros_rescale": decode_i4_to_f16_scale_zeros_rescale,
"i4_to_f16_scale_zeros_rescale_offset": decode_i4_to_f16_scale_zeros_rescale_offset,
"i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale,
"i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale,
"i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized,
"i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized,
"i4_to_f16_scale_zeros_quantized_offset": decode_i4_to_f16_scale_zeros_quantized_offset,
"i1_to_i8": decode_i1s_to_i8s,
"i2_to_i8": decode_i2s_to_i8s,
"i4_to_i8": decode_i4s_to_i8s,
"i2_to_i4": decode_i2s_to_i4s,
}
key = f"i{source_bit}_to_{target_dtype}"
if with_scaling:
key += "_scale"
if with_zeros:
key += f"_zeros_{zeros_mode}"
is_ladder_stage3 = (storage_scope == "warp") and with_scaling
if is_ladder_stage3:
key += "_offset"
if out_dtype == "float16":
d4f = "f16"
elif out_dtype == "int8":
d4f = "i8s"
elif out_dtype == "int4":
d4f = "i4s"
else:
raise ValueError("Unsupported target dtype: {}".format(target_dtype))
source_symbol = "u" if source_format == "uint" else "s"
func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f)
if with_scaling:
func_name += "_scale"
if with_zeros:
func_name += f"_zeros_{zeros_mode}"
if is_ladder_stage3:
func_name += "_offset"
return {
"func_name": func_name,
"c_source": import_c_map[key],
}
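# Usage sketch (illustrative, not executed here): request the plain unsigned
# 4-bit -> float16 decoder and inspect what the returned group contains.
#
#   group = get_lop3_intrin_group(out_dtype="float16", source_format="uint",
#                                 source_bit=4, storage_dtype="int8")
#   group["func_name"]  # -> "decode_i4u_to_f16"
#   group["c_source"]   # CUDA snippet compiled alongside the kernel that calls it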
# Copyright 2018 The apache/tvm Authors. All Rights Reserved.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
# The code below is mostly copied from mlc.ai quantization.py in mlc-llm.
# pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""TIR computation utilities for quantization."""
from tilelang import tvm as tvm
from tvm import tir
# fmt: off
def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
mask = tir.const((1 << 16) - 1, "uint32")
res = []
for data in [v0, v1]:
u32_val = tir.reinterpret("uint32", data)
if round_to_even:
rounding_bias = ((u32_val >> tir.const(16, "uint32"))
& tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32")
u32_val += rounding_bias
res.append((u32_val >> tir.const(16, "uint32")) & mask)
return res[0] | (res[1] << tir.const(16, "uint32"))
def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr):
mask = tir.const((1 << 16) - 1, "uint32")
x0 = x & mask
x1 = (x >> 16) & mask
return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1])
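# Worked example for the bf16 packing above (round-to-even): f32 1.0 is
# 0x3F800000, bit 16 is 0, so the bias is 0x7FFF and
# (0x3F800000 + 0x7FFF) >> 16 = 0x3F80, i.e. bf16(1.0); unpacking shifts it
# back, 0x3F80 << 16 = 0x3F800000 -> 1.0.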
def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == "uint32"
mask = tvm.tir.const((1 << nbit) - 1, "uint32")
return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask)
def _tir_packed_uint_to_uint_to_float(storage_nbit: int):
storage_dtype = "uint" + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1)) - 1
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert
def _tir_packed_int_to_int_to_float(storage_nbit: int):
storage_dtype = "int" + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32")
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
return f_convert
def _tir_f32_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float32"
val_u32 = tir.reinterpret("uint32", val)
# e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7)
# e_f32 == 120 -> e_f4 = 1
# e_f32 < 120 -> e_f4 = 0
m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32")
e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32")
s = (val_u32 >> tir.const(31, "uint32"))
e_f4 = tir.Select(
e_f32 > tir.const(120, "uint32"),
tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")),
tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"),
tir.const(0, "uint32")))
return (s << tir.const(3, "uint32")) | e_f4
def _tir_f16_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float16"
val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val))
m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32")
e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32")
s = (val_u32 >> tir.const(15, "uint32"))
e_f4 = tir.Select(
e_f16 > tir.const(8, "uint32"),
tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")),
tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32")))
return (s << tir.const(3, "uint32")) | e_f4
def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float32"
assert val.dtype == "uint32"
# e_f4 == 0 -> e_f32 = 0
# e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2
mask = tvm.tir.const((1 << nbit) - 1, "uint32")
f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask
s = f4 >> tir.const(3, "uint32")
e_f4 = f4 & tir.const(7, "uint32")
e_f32 = e_f4 | tir.const(120, "uint32")
val_f32 = tir.reinterpret("float32",
(e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32"))
return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32)
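# Round-trip sanity check for the 4-bit float helpers above (sign + 3-bit
# exponent, no stored mantissa): f32 1.0 has e_f32 = 127 and m_h = 0, so the
# encoder yields e_f4 = min(127 - 120, 7) = 7 and f4 = 0b0111; the decoder
# rebuilds e_f32 = 7 | 120 = 127 and reinterprets 127 << 23 back to 1.0.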
def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint32"
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
mask = tvm.tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)
def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tvm.tir.const((1 << nbit) - 1, storage_dtype)
f4 = (val >> (pos.astype(storage_dtype) * tir.const(nbit, storage_dtype))) & mask
s = f4 >> tir.const(3, storage_dtype)
e_f4 = f4 & tir.const(7, storage_dtype)
e_f16 = e_f4 | tir.const(8, storage_dtype)
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype("uint16"))
return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, "float16"), val_f16)
return f_convert
def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"),
tir.const(0x4000, "uint16"))
e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix
return tir.reinterpret("float16", s_f16 | e_f16)
def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
return tir.reinterpret("float16", s_f16 | e_f16)
def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
return tir.reinterpret("e5m2_float8", val).astype("float16")
def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1))
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert
def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tvm.tir.const((1 << nbit) - 1, storage_dtype)
return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype)
return f_convert
def _tir_packed_to_unsigned_convert_with_zeros(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm.tir.PrimExpr,
dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tvm.tir.const((1 << nbit) - 1, storage_dtype)
return (((val >> (pos * nbit).astype(storage_dtype)) & mask) - zero).astype(dtype)
return f_convert
def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32")
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
return f_convert
# fmt: on
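# Usage sketch for the packed converters above (an assumption about the call
# site, mirroring how a dequantize GEMM kernel indexes packed weights):
#   decode = _tir_packed_to_unsigned_convert("int", 8)
#   # element j of a 4-bit weight reads nibble (j % 2) of packed byte j // 2:
#   #   B_dequant[i, j] = decode(4, B_packed[i, j // 2], j % 2, "float16")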
def gen_quant4(k, n, groupsize=-1):
"""Generate a random fp16 weight and its 4-bit quantized form.
Returns (original_w, linear, scales, q) where q holds signed values recentered
around zero and linear is an nn.Linear carrying the dequantized reference weight."""
import torch
import torch.nn as nn
maxq = 2**4
w = torch.randn((k, n), dtype=torch.half, device="cpu")
original_w = w.clone()
if groupsize == -1:
groupsize = k
if groupsize != -1:
w = w.reshape((-1, groupsize, n))
w = w.permute(1, 0, 2)
w = w.reshape((groupsize, -1))
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / maxq
# Quantize.
w = torch.round(w / s).int()
# Unsigned storage.
w += (maxq) // 2
w = torch.clamp(w, 0, maxq)
# Dequantize.
ref = (w - (maxq) // 2).half() * s
if groupsize != -1:
def reshape(w):
w = w.reshape((groupsize, -1, n))
w = w.permute(1, 0, 2)
w = w.reshape((k, n)).contiguous()
return w
ref = reshape(ref)
w = reshape(w)
s = s.reshape((-1, n)).contiguous()
linear = nn.Linear(k, n, bias=False)
linear.weight.data = ref.t()
return original_w, linear, s, (w - (maxq) // 2)
def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None):
"""Pack low-precision integer elements along the last axis, 8 // source_bits elements per byte (low bits first)."""
import torch
if storage_dtype is None:
storage_dtype = torch.int8
elems_per_byte = 8 // source_bits
if lowprecision_weight.dtype == torch.float16:
lowprecision_weight = lowprecision_weight.to(torch.int8)
int8_weight = torch.zeros(
(*lowprecision_weight.shape[:-1], lowprecision_weight.shape[-1] // elems_per_byte),
dtype=torch.int8,
device=lowprecision_weight.device)
for j in range(lowprecision_weight.shape[-1] // elems_per_byte):
for k in range(elems_per_byte):
int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] <<
(source_bits * k)).to(torch.int8)
return int8_weight.to(storage_dtype)
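# Packing example (hypothetical values): with source_bits=4, the row elements
# [1, 2] collapse into a single int8 byte 1 | (2 << 4) == 0x21, i.e. the first
# element occupies the low nibble.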
# interleave weight (torch implementation)
def interleave_weight(qweight, nbits=4, target_dtype="float16"):
"""Interleave the weight to the target data type.
Args:
qweight (_type_): _description_
nbits (int, optional): _description_. Defaults to 4.
target_dtype (str, optional): _description_. Defaults to "float16".
Returns:
_type_: _description_
Example:
qweight = torch.randint(0, 127, (10, 10), dtype=torch.int8).cuda()
interleave_weight(qweight, 4, "float16")
"""
import torch
assert target_dtype in ["float16", "int8"]
# reinterpret the data type of qweight to int32
qweight = qweight.view(torch.int32)
new_qweight = torch.zeros_like(qweight)
bits_stride = 8 if target_dtype == "int8" else 16
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
num_groups = 32 // bits_stride
elems_per_group = bits_stride // nbits
for i in range(num_groups):
for j in range(elems_per_group):
offset = i * elems_per_group + j
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
if nbits == 1 and target_dtype == "int8":
# special handling for 1b interleave
n16_weight = new_qweight & torch.int32(0xF0F00F0F)
n16_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 16
n16_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24
n16_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4
n16_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 12
return n16_weight.view(torch.int8)
elif nbits == 2 and target_dtype == "float16":
n8_weight = new_qweight & torch.int32(0xFF0000FF)
n8_weight |= ((new_qweight & torch.int32(0x0000FF00)) >> 8) << 16
n8_weight |= ((new_qweight & torch.int32(0x00FF0000)) >> 16) << 8
return n8_weight.view(torch.int8)
elif nbits == 1 and target_dtype == "float16":
n8_weight = new_qweight & torch.int32(0xF000000F)
n8_weight |= ((new_qweight & torch.int32(0x000000F0)) >> 4) << 8
n8_weight |= ((new_qweight & torch.int32(0x00000F00)) >> 8) << 16
n8_weight |= ((new_qweight & torch.int32(0x0000F000)) >> 12) << 24
n8_weight |= ((new_qweight & torch.int32(0x000F0000)) >> 16) << 4
n8_weight |= ((new_qweight & torch.int32(0x00F00000)) >> 20) << 12
n8_weight |= ((new_qweight & torch.int32(0x0F000000)) >> 24) << 20
return n8_weight.view(torch.int8)
return new_qweight.view(torch.int8)
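# Putting the helpers together (a sketch; exact shapes and offsets follow the
# dequantize GEMM example, not this comment; unsigned_int4_weight is a
# placeholder tensor of values in [0, 15]):
#   packed = general_compress(unsigned_int4_weight, source_bits=4)           # 2 elems per byte
#   interleaved = interleave_weight(packed, nbits=4, target_dtype="float16")
#   # `interleaved` is the layout the LOP3-based decode_* intrinsics read from.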