"docs/source/models.decoder.rst" did not exist on "b374cc7b4e40373b505b8ed73908beec782254f5"
Unverified Commit 45e3a7bc authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

use sgl_per_token_group_quant_fp8 kernel (#3493)

parent b96e92e6
......@@ -25,7 +25,7 @@ runtime_common = [
]
srt = [
"sglang[runtime_common]", "cuda-python",
"sgl-kernel>=0.0.3.post3", "torch", "vllm>=0.6.4.post1,<=0.7.2",
"sgl-kernel>=0.0.3.post4", "torch", "vllm>=0.6.4.post1,<=0.7.2",
"flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
]
......
......@@ -33,6 +33,10 @@ _is_rocm = torch.cuda.is_available() and torch.version.hip
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
if _is_cuda or _is_rocm:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
......@@ -488,7 +492,10 @@ def invoke_fused_moe_kernel(
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
if _is_cuda:
A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
else:
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
......
......@@ -27,6 +27,10 @@ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
from sgl_kernel import sgl_per_token_group_quant_fp8
logger = logging.getLogger(__name__)
......@@ -135,6 +139,36 @@ def per_token_group_quant_fp8(
return x_q, x_s
def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
):
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
return x_q, x_s
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment