Commit 8da47f19 authored by renzhc

Support w8a8 compile; register custom ops to resolve some graph-break issues

parent bd63af06
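The recurring pattern in this change: device kernels that Dynamo cannot trace are wrapped and registered as custom ops with a shape-only fake implementation, so torch.compile records a single opaque node instead of breaking the graph; paths that still cannot be compiled are fenced off with @torch._dynamo.disable(). A minimal sketch of the same idea using the stock torch.library API (the diff itself uses sglang's direct_register_custom_op helper; the demo::opt_cat name and shapes below are illustrative, not part of the change):

import torch

@torch.library.custom_op("demo::opt_cat", mutates_args=())
def opt_cat(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    # Real implementation: a plain torch.cat stands in here for an opaque
    # device kernel that writes into a preallocated output buffer.
    return torch.cat([a, b], dim=dim)

@opt_cat.register_fake
def _(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    # Fake (meta) implementation: only the output shape and dtype matter.
    shape = list(a.shape)
    shape[dim] = a.shape[dim] + b.shape[dim]
    return a.new_empty(shape)

@torch.compile(fullgraph=True)
def fused_step(a, b):
    # The registered op is traced as a single node; no graph break.
    return torch.ops.demo.opt_cat(a, b, 1) + 1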
@@ -362,7 +362,7 @@ class DCUMLABackend(AttentionBackend):
        )
        return o

-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # NOTE: decode with the FP8 KV cache does not support compile
    def forward_decode(
        self,
        q: torch.Tensor,
@@ -417,7 +417,7 @@ class DCUMLABackend(AttentionBackend):
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    @torch._dynamo.disable()  # NOTE: untested
+    @torch._dynamo.disable()
    def forward_extend(
        self,
        q: torch.Tensor,
...
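@torch._dynamo.disable() is the fallback used above when a path cannot be compiled at all: the decorated method always runs eagerly, while the surrounding compiled code keeps working. A minimal, self-contained sketch (the function names are illustrative):

import torch

@torch._dynamo.disable()
def fp8_decode_stub(x):
    # Placeholder for a kernel path Dynamo cannot trace; it runs eagerly.
    return x * 2.0

@torch.compile
def decode_step(x):
    # The compiled region calls the disabled function without tracing into it.
    return fp8_decode_stub(x) + 1.0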
@@ -4,7 +4,7 @@ import warnings

import torch

-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.utils import get_bool_env_var, direct_register_custom_op

_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")
@@ -18,15 +18,50 @@ if _USE_OPT_CAT:
    )
else:
    ds_cat = None

-def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
-    assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
+
+# TODO: registering this op on its own still has some issues
+def ds_cat_wrapper(A: torch.Tensor,
+                   B: torch.Tensor,
+                   dim: int,
+                   mode: int) -> torch.Tensor:
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
+    ds_cat(A, B, C, mode)
+    return C
+
+
+def ds_cat_fake(A: torch.Tensor,
+                B: torch.Tensor,
+                dim: int,
+                mode: int) -> torch.Tensor:
+    # Use a plain torch.cat as the fake implementation
+    return torch.cat([A, B], dim=dim)
+
+
+direct_register_custom_op(
+    op_name="ds_cat",
+    op_func=ds_cat_wrapper,
+    mutates_args=[],  # no arguments are mutated; only a return value
+    fake_impl=ds_cat_fake,
+)
+
+
+def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
+    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
    mode = 0
-    if dim!=0 :
-        ds_cat( A, B, C, mode)
-        return C
-    assert False, "not support"
\ No newline at end of file
+    if dim != 0:
+        return torch.ops.sglang.ds_cat(A, B, dim, mode)
+    assert False, "not support"
+
+# def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
+#     assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
+#     output_shape = list(A.shape)
+#     output_shape[dim] = A.shape[dim] + B.shape[dim]
+#     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
+#     mode=0
+#     if dim!=0 :
+#         ds_cat(A, B, C, mode)
+#         return C
+#     assert False, "not support"
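With torch.cat as the fake implementation, tracing sees the correct output shape and dtype, while the real lightop kernel still runs at execution time. A hedged usage sketch (assumes SGLANG_USE_OPT_CAT is enabled and the lightop ds_cat kernel is available on this build; the shapes are illustrative):

A = torch.randn(8, 16, 512, device="cuda", dtype=torch.bfloat16)
B = torch.randn(8, 16, 64, device="cuda", dtype=torch.bfloat16)
C = concat_decode_opt(A, B, dim=2)   # dispatches to torch.ops.sglang.ds_cat
assert C.shape == (8, 16, 576)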
@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu, direct_register_custom_op
from sglang.srt.utils.offloader import get_offloader

if TYPE_CHECKING:
@@ -57,6 +57,105 @@ if _use_aiter:

logger = logging.getLogger(__name__)

+
+# ------ custom ops for lightop ------
+def m_grouped_w4a8_gemm_nt_masked_wrapper(
+    a0: torch.Tensor, a1: torch.Tensor,
+    b0: torch.Tensor, b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return m_grouped_w4a8_gemm_nt_masked(
+        (a0, a1),
+        (b0, b1),
+        d,
+        masked_m,
+        expected_m_per_group,
+        config={"MODE": 1000},
+    )
+
+
+def m_grouped_w4a8_gemm_nt_masked_fake(
+    a0: torch.Tensor, a1: torch.Tensor,
+    b0: torch.Tensor, b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return d
+
+
+def m_grouped_w8a8_gemm_nt_masked_wrapper(
+    a0: torch.Tensor, a1: torch.Tensor,
+    b0: torch.Tensor, b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return m_grouped_w8a8_gemm_nt_masked(
+        (a0, a1),
+        (b0, b1),
+        d,
+        masked_m,
+        expected_m_per_group,
+        config={"MODE": 1000},
+    )
+
+
+def m_grouped_w8a8_gemm_nt_masked_fake(
+    a0: torch.Tensor, a1: torch.Tensor,
+    b0: torch.Tensor, b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return d
+
+
+def fuse_silu_mul_quant_ep_wrapper(
+    input: torch.Tensor,
+    tokens_per_expert: Optional[torch.Tensor] = None,
+    num_local_tokens_tensor: Optional[torch.Tensor] = None,
+    topk: int = 1,
+    expect_m: int = -1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return fuse_silu_mul_quant_ep(
+        input,
+        tokens_per_expert,
+        num_local_tokens_tensor,
+        topk,
+        expect_m,
+    )
+
+
+def fuse_silu_mul_quant_ep_fake(
+    input: torch.Tensor,
+    tokens_per_expert: Optional[torch.Tensor] = None,
+    num_local_tokens_tensor: Optional[torch.Tensor] = None,
+    topk: int = 1,
+    expect_m: int = -1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    E, T, H = input.shape
+    d = H // 2
+    output = torch.empty(E, T, d, dtype=torch.int8, device=input.device)
+    scales = torch.empty((E, T, 1), device=input.device, dtype=torch.float32)
+    return output, scales
+
+
+direct_register_custom_op(
+    op_name="m_grouped_w4a8_gemm_nt_masked",
+    op_func=m_grouped_w4a8_gemm_nt_masked_wrapper,
+    mutates_args=[],
+    fake_impl=m_grouped_w4a8_gemm_nt_masked_fake,
+)
+direct_register_custom_op(
+    op_name="m_grouped_w8a8_gemm_nt_masked",
+    op_func=m_grouped_w8a8_gemm_nt_masked_wrapper,
+    mutates_args=[],
+    fake_impl=m_grouped_w8a8_gemm_nt_masked_fake,
+)
+direct_register_custom_op(
+    op_name="fuse_silu_mul_quant_ep",
+    op_func=fuse_silu_mul_quant_ep_wrapper,
+    mutates_args=[],
+    fake_impl=fuse_silu_mul_quant_ep_fake,
+)
+# ------

# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@@ -815,23 +914,23 @@ class DeepEPMoE(EPMoE):
        gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)

        # ---- first GEMM ----
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale,
+            w13_weight, w13_scales,
            gateup_output,
            masked_m,
            expected_m,
        )

-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)

        # ---- second GEMM ----
        n2 = w2_scales.size(1)
        down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale,
+            w2_weight, w2_scales,
            down_output,
            masked_m,
            expected_m,
@@ -865,23 +964,23 @@ class DeepEPMoE(EPMoE):
        gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)

        # ---- first GEMM ----
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale,
+            w13_weight, w13_scales,
            gateup_output,
            masked_m,
            expected_m,
        )

-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)

        # ---- second GEMM ----
        n2 = w2_scales.size(1)
        down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale,
+            w2_weight, w2_scales,
            down_output,
            masked_m,
            expected_m,
...
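The fake implementations above only describe output shapes and dtypes: the grouped-GEMM fakes simply return the preallocated d, and the fused activation-quant fake allocates empty int8 output and float32 scale tensors. That is all torch.compile needs to trace the MoE path without executing the lightop kernels. A minimal sketch of that contract with the stock torch.library API (the demo::masked_gemm op, its reference loop, and the shapes are illustrative, not sglang code):

import torch

@torch.library.custom_op("demo::masked_gemm", mutates_args=())
def masked_gemm(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor,
                masked_m: torch.Tensor, expected_m: int) -> torch.Tensor:
    # Reference stand-in for the opaque grouped kernel: for each expert group
    # g, only the first masked_m[g] rows hold valid tokens.
    out = torch.zeros_like(d)
    for g in range(a.shape[0]):
        rows = int(masked_m[g])
        out[g, :rows] = a[g, :rows] @ b[g].transpose(0, 1)
    return out

@masked_gemm.register_fake
def _(a, b, d, masked_m, expected_m):
    # Shape/dtype only, mirroring the fakes above that just return `d`.
    return torch.empty_like(d)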
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
@@ -15,6 +16,7 @@ from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import set_weight_attrs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend

try:
    from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
@@ -77,7 +79,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")
-        self.use_deepep = True
+        self.use_deepep = get_moe_a2a_backend().is_deepep()
        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
...
@@ -163,12 +163,14 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
        )
        layer.register_parameter("input_zero_point", input_zero_point)

+    @torch._dynamo.disable()
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # TODO: add cutlass_scaled_mm_azp support
        x_q, x_scale = per_token_quant_int8(x)
+        # TODO: fix with lmslim/lightop
        return quant_ops.triton_scaled_mm(
            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
        )
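For intuition, a hedged pure-PyTorch reference of what this W8A8 path computes (the *_ref helpers are illustrative; the real per_token_quant_int8 and triton_scaled_mm kernels are treated as opaque here, and the weight layout and per-channel scale shape are assumptions):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per token (row): scale = max|x| / 127.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-7) / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_q, scale.to(torch.float32)

def scaled_mm_ref(x_q, w_q, x_scale, w_scale, out_dtype, bias=None):
    # int8 x int8 accumulated in int32, then rescaled per token and per channel.
    acc = (x_q.to(torch.int32) @ w_q.to(torch.int32)).to(torch.float32)
    out = acc * x_scale * w_scale  # x_scale: [M, 1]; w_scale: [N] (assumed per-channel)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)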
@@ -157,7 +157,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
        )
        layer.register_parameter("weight_scale", weight_scale)

-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
    def apply(
        self,
        layer: torch.nn.Module,
...
@@ -214,7 +214,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
    def apply(
        self,
        layer: torch.nn.Module,
@@ -253,6 +253,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
        )
        return StandardCombineInput(hidden_states=output)

+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
    def apply_with_shared_output(
        self,
        layer: torch.nn.Module,
...