# Adapted from https://github.com/sgl-project/sglang/blob/4cb53ecd0cffceb6dee5c011a58f65997a86f151/python/sglang/srt/layers/quantization/int8_kernel.py
"""Perform block-wise dequantization.

The inputs are the block-wise quantized tensor `x_q_block`, the
block-wise quantization scales `x_s`, and the block size `block_size`.
The output is the dequantized tensor in float32.
"""
block_n,block_k=block_size[0],block_size[1]
n,k=x_q_block.shape
n_tiles=(n+block_n-1)//block_n
k_tiles=(k+block_k-1)//block_k
assertn_tiles==x_s.shape[0]
assertk_tiles==x_s.shape[1]
x_dq_block=x_q_block.to(torch.float32)
foriinrange(k_tiles):
forjinrange(n_tiles):
x_dq_block[
j*block_n:min((j+1)*block_n,n),
i*block_k:min((i+1)*block_k,k),
]*=x_s[j][i]
returnx_dq_block
@triton.jit
def_per_token_quant_int8(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
BLOCK:tl.constexpr,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282