Unverified commit a086a113 authored by lambert0312, committed by GitHub

Use sgl-kernel sgl_per_token_group_quant_int8 (#4971)

parent bdbe5f81
@@ -755,6 +755,9 @@ def invoke_fused_moe_kernel(
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+        from sglang.srt.layers.quantization.int8_kernel import (
+            sglang_per_token_group_quant_int8,
+        )
     else:
         from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
@@ -794,7 +797,10 @@ def invoke_fused_moe_kernel(
             # activation block-wise int8 quantization
             assert len(block_shape) == 2
             block_n, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_int8(A, block_k)
+            if _is_cuda:
+                A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+            else:
+                A, A_scale = per_token_group_quant_int8(A, block_k)
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
...
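With this change, activation block-wise int8 quantization in the MoE kernel dispatches to the fused sgl-kernel CUDA op when running on CUDA and keeps the existing Triton `per_token_group_quant_int8` as the fallback. For orientation, below is a minimal pure-PyTorch sketch of the math both paths implement (symmetric per-group int8: one float32 scale per group of `block_k` values); the helper name and shapes are illustrative only and are not part of this diff.

# Illustrative reference only; the Triton and sgl-kernel paths above compute
# an equivalent result much faster on GPU.
import torch

def per_group_int8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Reshape the last dimension into groups of `group_size` values.
    orig_shape = x.shape
    x_grouped = x.reshape(*orig_shape[:-1], orig_shape[-1] // group_size, group_size)
    # One float32 scale per group: absmax / 127, clamped away from zero.
    absmax = x_grouped.abs().amax(dim=-1, keepdim=True).float()
    scale = absmax.clamp(min=eps) / 127.0
    # Quantize, round, and clamp into the int8 range.
    x_q = (x_grouped / scale).round().clamp(-128, 127).to(torch.int8)
    return x_q.reshape(orig_shape), scale.squeeze(-1)

# Example: 4 tokens with hidden size 256, quantized in groups of 128.
x = torch.randn(4, 256)
x_q, x_s = per_group_int8_reference(x, group_size=128)
assert x_q.shape == (4, 256) and x_s.shape == (4, 2)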
@@ -8,7 +8,11 @@ import torch
 import triton
 import triton.language as tl
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_int8
 
 logger = logging.getLogger(__name__)
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` must be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_int8_matmul(
     # Pointers to inputs and output
...
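Assuming a CUDA build of sglang with the sgl-kernel package installed, the new wrapper can be exercised directly; the tensor shapes below are chosen purely for illustration.

# Illustrative usage of the wrapper added in this diff.
# Requires a CUDA device and sgl-kernel; shapes are arbitrary.
import torch
from sglang.srt.layers.quantization.int8_kernel import (
    sglang_per_token_group_quant_int8,
)

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)
# x_q: int8 tensor with the same shape as x.
# x_s: float32 scales, one per group -> shape (16, 4096 // 128) == (16, 32).
print(x_q.shape, x_q.dtype, x_s.shape, x_s.dtype)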