Commit 769353e6 authored by lizhigong

Merge branch 'v0.5.4_rzc' into 'v0.5.4_dev'

Support w8a8 compile; register custom ops to resolve some graph-break issues

See merge request OpenDAS/sglang!31
parents 0dc51b09 8da47f19
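The pattern used throughout this change: a kernel that torch.compile cannot trace is wrapped in a plain function, given a shape-only fake implementation, and registered via direct_register_custom_op, so call sites go through torch.ops.sglang and the graph no longer breaks. A minimal sketch of that pattern follows; only direct_register_custom_op and the torch.ops.sglang namespace come from this merge request, the kernel name my_kernel and its body are hypothetical.

import torch
from sglang.srt.utils import direct_register_custom_op

def my_kernel_wrapper(x: torch.Tensor) -> torch.Tensor:
    # stand-in for a call into an extension kernel that torch.compile cannot trace
    return x * 2

def my_kernel_fake(x: torch.Tensor) -> torch.Tensor:
    # fake impl only has to produce correct shapes/dtypes for tracing
    return torch.empty_like(x)

direct_register_custom_op(
    op_name="my_kernel",        # hypothetical op name, for illustration only
    op_func=my_kernel_wrapper,
    mutates_args=[],
    fake_impl=my_kernel_fake,
)

# call sites use the registered op, which the compiled graph treats as opaque
y = torch.ops.sglang.my_kernel(torch.randn(4))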
......@@ -420,7 +420,7 @@ class DCUMLABackend(AttentionBackend):
)
return o
@torch._dynamo.disable()
@torch._dynamo.disable() # NOTE: FP8 cache decode does not support compile
def forward_decode(
self,
q: torch.Tensor,
......@@ -475,7 +475,7 @@ class DCUMLABackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@torch._dynamo.disable() # NOTE: untested
@torch._dynamo.disable()
def forward_extend(
self,
q: torch.Tensor,
......
......@@ -4,7 +4,7 @@ import warnings
import torch
from sglang.srt.utils import get_bool_env_var
from sglang.srt.utils import get_bool_env_var, direct_register_custom_op
_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")
......@@ -18,15 +18,50 @@ if _USE_OPT_CAT:
)
else:
ds_cat = None
def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
# TODO: registering this separately has some issues
def ds_cat_wrapper(A: torch.Tensor,
B: torch.Tensor,
dim: int,
mode: int) -> torch.Tensor:
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
ds_cat(A, B, C, mode)
return C
def ds_cat_fake(A: torch.Tensor,
B: torch.Tensor,
dim: int,
mode: int) -> torch.Tensor:
# use the standard cat as the fake implementation
return torch.cat([A, B], dim=dim)
direct_register_custom_op(
op_name="ds_cat",
op_func=ds_cat_wrapper,
mutates_args=[], # no mutated arguments, only a return value
fake_impl=ds_cat_fake
)
def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
mode = 0
if dim!=0 :
ds_cat( A, B, C, mode)
return C
assert False, "not support"
\ No newline at end of file
if dim != 0:
return torch.ops.sglang.ds_cat(A, B, dim, mode)
assert False, "not support"
# def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
# assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
# output_shape = list(A.shape)
# output_shape[dim] = A.shape[dim] + B.shape[dim]
# C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
# mode=0
# if dim!=0 :
# ds_cat(A, B, C, mode)
# return C
# assert False, "not support"
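For reference, a hedged usage sketch of the newly registered op: it assumes the ds_cat extension is available (SGLANG_USE_OPT_CAT set), and the shapes and dtype are illustrative only.

import torch

A = torch.randn(2, 8, 16, device="cuda", dtype=torch.float16)
B = torch.randn(2, 8, 4, device="cuda", dtype=torch.float16)
# concat along dim=2 with mode=0, the only case concat_decode_opt supports
C = torch.ops.sglang.ds_cat(A, B, 2, 0)
assert C.shape == (2, 8, 20)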
......@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu, direct_register_custom_op
from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
......@@ -57,6 +57,105 @@ if _use_aiter:
logger = logging.getLogger(__name__)
#------ custom op for lightop
def m_grouped_w4a8_gemm_nt_masked_wrapper(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return m_grouped_w4a8_gemm_nt_masked(
(a0, a1),
(b0, b1),
d,
masked_m,
expected_m_per_group,
config={"MODE": 1000,}
)
def m_grouped_w4a8_gemm_nt_masked_fake(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return d
def m_grouped_w8a8_gemm_nt_masked_wrapper(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return m_grouped_w8a8_gemm_nt_masked(
(a0, a1),
(b0, b1),
d,
masked_m,
expected_m_per_group,
config={"MODE": 1000,}
)
def m_grouped_w8a8_gemm_nt_masked_fake(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return d
def fuse_silu_mul_quant_ep_wrapper(
input: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None,
num_local_tokens_tensor: Optional[torch.Tensor] = None,
topk:int=1,
expect_m:int=-1) -> tuple[torch.Tensor, torch.Tensor]:
return fuse_silu_mul_quant_ep(
input,
tokens_per_expert,
num_local_tokens_tensor,
topk,
expect_m
)
def fuse_silu_mul_quant_ep_fake(
input: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None,
num_local_tokens_tensor: Optional[torch.Tensor] = None,
topk:int=1,
expect_m:int=-1) -> tuple[torch.Tensor, torch.Tensor]:
E, T, H = input.shape
d = H // 2
output = torch.empty(E, T, d, dtype=torch.int8, device=input.device)
scales = torch.empty((E, T, 1),
device=input.device,
dtype=torch.float32)
return output, scales
direct_register_custom_op(
op_name="m_grouped_w4a8_gemm_nt_masked",
op_func=m_grouped_w4a8_gemm_nt_masked_wrapper,
mutates_args=[],
fake_impl=m_grouped_w4a8_gemm_nt_masked_fake
)
direct_register_custom_op(
op_name="m_grouped_w8a8_gemm_nt_masked",
op_func=m_grouped_w8a8_gemm_nt_masked_wrapper,
mutates_args=[],
fake_impl=m_grouped_w8a8_gemm_nt_masked_fake
)
direct_register_custom_op(
op_name="fuse_silu_mul_quant_ep",
op_func=fuse_silu_mul_quant_ep_wrapper,
mutates_args=[],
fake_impl=fuse_silu_mul_quant_ep_fake
)
#------
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
......@@ -815,23 +914,23 @@ class DeepEPMoE(EPMoE):
gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)
# ---- first GEMM ----
m_grouped_w4a8_gemm_nt_masked(
(q_a1_all, q_a1_scale),
(w13_weight, w13_scales),
torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
q_a1_all, q_a1_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
expected_m,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
# ---- second GEMM ----
n2 = w2_scales.size(1)
down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
m_grouped_w4a8_gemm_nt_masked(
(q_a2_all, q_a2_scale),
(w2_weight, w2_scales),
torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
q_a2_all, q_a2_scale,
w2_weight, w2_scales,
down_output,
masked_m,
expected_m,
......@@ -865,23 +964,23 @@ class DeepEPMoE(EPMoE):
gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)
# ---- first GEMM ----
m_grouped_w8a8_gemm_nt_masked(
(q_a1_all, q_a1_scale),
(w13_weight, w13_scales),
torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
q_a1_all, q_a1_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
expected_m,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
# ---- second GEMM ----
n2 = w2_scales.size(1)
down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
m_grouped_w8a8_gemm_nt_masked(
(q_a2_all, q_a2_scale),
(w2_weight, w2_scales),
torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
q_a2_all, q_a2_scale,
w2_weight, w2_scales,
down_output,
masked_m,
expected_m,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
......
......@@ -163,12 +163,14 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
)
layer.register_parameter("input_zero_point", input_zero_point)
@torch._dynamo.disable()
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
# TODO: add cutlass_scaled_mm_azp support
x_q, x_scale = per_token_quant_int8(x)
# TODO: fix with lmslim/lightop
return quant_ops.triton_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
......@@ -157,7 +157,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
)
layer.register_parameter("weight_scale", weight_scale)
@torch._dynamo.disable()
@torch._dynamo.disable() # TODO: performance optimization requires lmslim/lightop support
def apply(
self,
layer: torch.nn.Module,
......
......@@ -214,7 +214,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
@torch._dynamo.disable()
@torch._dynamo.disable() # TODO: performance optimization requires lmslim/lightop support
def apply(
self,
layer: torch.nn.Module,
......@@ -253,6 +253,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
)
return StandardCombineInput(hidden_states=output)
@torch._dynamo.disable() # TODO: performance optimization requires lmslim/lightop support
def apply_with_shared_output(
self,
layer: torch.nn.Module,
......