"vscode:/vscode.git/clone" did not exist on "0fecb2ddb9830b03e69bb8fe77a4596a8b7edf66"
Commit 3dc7dc6c authored by maxiao1

Operator fusion: add fused RMSNorm+quant and SiLU-mul+quant paths, a lightop MoE sum-mul-add path that folds in the shared-expert output, and an optimized decode-time concat, all gated by new SGLANG_USE_* environment flags.

parent 842b423a
@@ -167,6 +167,14 @@ class Envs:
# DCU Lightop
SGLANG_USE_LIGHTOP = EnvBool(False)
# Fused
SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD = EnvBool(False)
SGLANG_USE_OPT_CAT = EnvBool(False)
SGLANG_USE_FUSED_RMS_QUANT = EnvBool(False)
SGLANG_USE_FUSED_SILU_MUL_QUANT = EnvBool(False)
# Quantization
SGLANG_INT4_WEIGHT = EnvBool(False)
SGLANG_CPU_QUANTIZATION = EnvBool(False)
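Note: these flags follow the existing EnvBool pattern above and, in the hunks below, are read at import time via get_bool_env_var. A minimal sketch of turning them on (flag names are from this diff; everything else is illustrative):

import os

# Set before sglang modules are imported, since the modules below read the
# flags at module import time.
os.environ["SGLANG_USE_FUSED_RMS_QUANT"] = "1"
os.environ["SGLANG_USE_FUSED_SILU_MUL_QUANT"] = "1"
os.environ["SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD"] = "1"
os.environ["SGLANG_USE_OPT_CAT"] = "1"

from sglang.srt.utils import get_bool_env_var

assert get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")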
from __future__ import annotations

import warnings

import torch

from sglang.srt.utils import get_bool_env_var

_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")

if _USE_OPT_CAT:
    try:
        from lightop import ds_cat  # type: ignore
    except ImportError:  # pragma: no cover
        ds_cat = None
        warnings.warn(
            "SGLANG_USE_OPT_CAT is enabled but lightop.ds_cat could not be imported; "
            "falling back to torch.cat"
        )
else:
    ds_cat = None


def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int) -> torch.Tensor:
    """Concatenate two 3-D tensors along dim=2 using lightop's ds_cat kernel."""
    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
    if ds_cat is None:
        # lightop is unavailable (or the flag is off): fall back to torch.cat.
        return torch.cat([A, B], dim=dim)
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    mode = 0
    ds_cat(A, B, C, mode)
    return C
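A quick parity check for the new helper, as a sketch only: it assumes a DCU/GPU device with lightop installed, and the shapes are illustrative rather than taken from the model.

import torch

from sglang.srt.layers.attention.lightop_concat import concat_decode_opt

# Decode-style 3-D tensors: [num_tokens, num_heads, head_dim].
a = torch.randn(8, 16, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(8, 16, 64, device="cuda", dtype=torch.bfloat16)

out = concat_decode_opt(a, b, dim=2)
ref = torch.cat([a, b], dim=-1)
torch.testing.assert_close(out, ref)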
@@ -44,6 +44,18 @@ _is_hip = is_hip()
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
    "SGLANG_ROCM_DISABLE_LINEARQUANT"
)
_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
if _use_fused_rms_quant:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        print(f"Error: Import fused rmsquant error: {e}")
if _use_fused_silu_mul_quant:
    try:
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        print(f"Error: Import fused silu_mul_quant error: {e}")
logger = logging.getLogger(__name__)
@@ -1358,7 +1370,7 @@ class RowParallelLinear(LinearBase):
# It does not support additional parameters.
param.load_row_parallel_weight(loaded_weight)
def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
if self.input_is_parallel:
    input_parallel = input_
else:
@@ -1372,9 +1384,19 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant:
    xq, xs = lm_fuse_silu_mul_quant(input_parallel)
    silu_quant_args = [xq, xs]
    with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
        output_parallel = self.quant_method.apply(
            self,
            input_parallel,
            bias=bias_,
            silu_quant_args=silu_quant_args,
        )
        sm.tag(output_parallel)
else:
    with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
        sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
    output = tensor_model_parallel_all_reduce(output_parallel)
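For context on the new use_fused_silu_mul_quant path: instead of applying the activation and quantizing inside the quant method, the caller quantizes once with lm_fuse_silu_mul_quant and hands the result to apply via silu_quant_args. The lmslim kernel itself is not shown in this diff; an unfused reference for what it is assumed to compute (SiLU(gate) * up followed by per-token int8 quantization; the helper name and scale convention here are assumptions) looks roughly like:

import torch
import torch.nn.functional as F

def silu_mul_quant_reference(gate_up: torch.Tensor):
    # gate_up packs [gate | up] along the last dim, as produced by gate_up_proj.
    gate, up = gate_up.chunk(2, dim=-1)
    x = F.silu(gate) * up
    # Per-token symmetric int8 quantization (scale convention assumed).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale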
@@ -42,6 +42,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.environ import envs
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import (
    cpu_has_amx_support,
@@ -58,6 +59,7 @@ if is_flashinfer_available():
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
# Try to import FP4 TRTLLM function if flashinfer is available
@@ -221,6 +223,7 @@ class FusedMoE(torch.nn.Module):
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results
self.use_presharded_weights = use_presharded_weights
# self.global_num_experts = self.num_experts
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
@@ -877,9 +880,21 @@ class FusedMoE(torch.nn.Module):
    f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
)
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput = None, shared_output: torch.Tensor = None, **kwargs):
origin_hidden_states_dim = hidden_states.shape[-1]
assert self.quant_method is not None
if _use_lightop_moe_sum_mul_add:
    final_hidden_states = self.quant_method.apply_with_shared_output(
        layer=self,
        x=hidden_states,
        activation=getattr(self, 'moe_runner_config', None) and self.moe_runner_config.activation or "silu",
        shared_output=shared_output,
        topk_output=topk_output,
    )
    if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states
dispatch_output = self.dispatcher.dispatch(
    hidden_states=hidden_states, topk_output=topk_output
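The apply_with_shared_output hook used here is added to the W4A8 Marlin MoE method further down; it lets the kernel consume shared_output directly instead of adding it afterwards. As a hedge on what the "sum_mul_add" epilogue is assumed to fuse, the unfused DeepseekV2MoE path below still scales the routed-expert output and adds the shared-expert output with a separate multiply and torch.add, roughly:

import torch

def moe_sum_mul_add_reference(
    routed_out: torch.Tensor,
    shared_out: torch.Tensor,
    routed_scaling_factor: float,
) -> torch.Tensor:
    # Single pass over the routed-expert output: scale, then add shared experts.
    # (Whether the fused lightop kernel applies the scaling itself is assumed here.)
    return routed_out * routed_scaling_factor + shared_out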
@@ -19,6 +19,9 @@ from vllm.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os
from sglang.srt.utils import get_bool_env_var
_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
@@ -163,13 +166,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
    input_quant_args: Optional[list[torch.Tensor]] = None,
    silu_quant_args: Optional[list[torch.Tensor]] = None
):
if _use_fused_rms_quant and input_quant_args is not None:
    assert len(input_quant_args) == 2
    x_q, x_scale = input_quant_args
elif _use_fused_silu_mul_quant and silu_quant_args is not None:
    x_q, x_scale = silu_quant_args
else:
    x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1:
    m=x_q.shape[0]
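Symmetrically, when SGLANG_USE_FUSED_RMS_QUANT is set the caller is expected to pass input_quant_args produced by lm_faster_rmsquant so that apply can skip per_token_quant_int8. An unfused reference for what that kernel is assumed to fuse (RMSNorm followed by the same per-token int8 quantization as above; eps, weight handling, and the scale convention are assumptions):

import torch

def rms_quant_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # RMSNorm over the hidden dimension, then per-token symmetric int8 quant.
    normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    scale = normed.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.clamp(torch.round(normed / scale), -128, 127).to(torch.int8), scale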
@@ -252,6 +252,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
    use_nn_moe=False,
)
return StandardCombineInput(hidden_states=output)
def apply_with_shared_output(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    activation: str = "silu",
    shared_output: Optional[torch.Tensor] = None,
    topk_output=None,
) -> torch.Tensor:
    topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
    workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
    return fused_experts_impl_w4a8_marlin(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        workspace=workspace,
        global_reduce_buffer=global_reduce_buffer,
        inplace=True,
        use_int4_w4a8=True,
        per_channel_quant=True,
        activation=activation,
        apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
        global_num_experts=layer.moe_runner_config.num_experts,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        use_nn_moe=False,
        shared_output=shared_output,
    )
# def _apply(
#     self,
#     layer: torch.nn.Module,
@@ -317,9 +350,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
#     a2_scale=layer.w2_input_scale,
#     use_nn_moe=use_nn_moe,
# )
#
def apply_ep(self,
    x: torch.Tensor,
@@ -368,4 +398,4 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
#config_select_bs=config_select_bs, #config_select_bs=config_select_bs,
#q_scales=scales #q_scales=scales
)
@@ -141,6 +141,7 @@ from sglang.srt.utils import (
    make_layers,
    use_intel_amx_backend,
)
from sglang.srt.layers.attention.lightop_concat import concat_decode_opt
_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -151,8 +152,10 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
_is_gfx95_supported = is_gfx95_supported()
_use_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
_use_opt_cat_decode = get_bool_env_var("SGLANG_USE_OPT_CAT")
if _use_aiter_gfx95:
    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
@@ -456,10 +459,13 @@ class DeepseekV2MLP(nn.Module):
x = (x, None, y)
gate_up, _ = self.gate_up_proj(x)
if _use_fused_silu_mul_quant:
    x, _ = self.down_proj(
        gate_up,
        skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
        use_fused_silu_mul_quant=True,
    )
else:
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(
        x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
    )
return x
@@ -757,49 +763,58 @@ class DeepseekV2MoE(nn.Module):
    self.shared_experts.gate_up_proj
):
    return self.forward_cpu(hidden_states, should_allreduce_fusion)
if _use_lightop_moe_sum_mul_add:
    if hidden_states.shape[0] > 0:
        if not self._fuse_shared_experts_inside_sbo:
            shared_output = self._forward_shared_experts(
                hidden_states, gemm_output_zero_allocator
            )
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
        topk_output = self.topk(hidden_states, router_logits)
        final_hidden_states = self.experts(
            hidden_states, topk_output, shared_output=shared_output
        )
else:
    if hidden_states.shape[0] > 0:
        if not self._fuse_shared_experts_inside_sbo:
            shared_output = self._forward_shared_experts(
                hidden_states, gemm_output_zero_allocator
            )
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
        topk_output = self.topk(hidden_states, router_logits)
    else:
        shared_output = None
        topk_output = self.topk.empty_topk_output(hidden_states.device)

    if self._fuse_shared_experts_inside_sbo:
        shared_output = None

        def _forward_shared_experts_and_put_results():
            nonlocal shared_output
            shared_output = self._forward_shared_experts(
                hidden_states, gemm_output_zero_allocator
            )

    final_hidden_states = self.experts(
        hidden_states,
        topk_output,
        **(
            dict(
                forward_shared_experts=_forward_shared_experts_and_put_results,
                alt_stream=self.alt_stream,
            )
            if self._fuse_shared_experts_inside_sbo
            else {}
        ),
    )
    if not _is_cuda and not _use_aiter:
        # fused in biased_grouped_topk so we can skip here
        final_hidden_states *= self.routed_scaling_factor
    if shared_output is not None:
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            final_hidden_states_out = torch.empty_like(final_hidden_states)
            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
            final_hidden_states = final_hidden_states_out
            sm.tag(final_hidden_states)
if (
    self.tp_size > 1
    and not should_allreduce_fusion
@@ -1696,7 +1711,10 @@ class DeepseekV2AttentionMLA(nn.Module):
    self.rotary_emb.is_neox_style,
)
else:
    if _use_opt_cat_decode and q_nope_out.shape[0] < 1024:
        q = concat_decode_opt(q_nope_out, q_pe, dim=2)
    else:
        q = torch.cat([q_nope_out, q_pe], dim=-1)
    k = torch.cat([k_nope, k_pe], dim=-1)
    attn_output = self.attn_mqa(
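In this decode branch q_nope_out and q_pe are 3-D, [num_tokens, num_heads, head_dim], so dim=2 addresses the same axis as dim=-1 and the lightop path is intended to produce the same layout as the torch.cat fallback; the 1024-token cutoff above simply keeps large batches on the stock kernel. A tiny sketch (shapes illustrative):

import torch

q_nope_out = torch.randn(4, 128, 512)  # [num_tokens, num_heads, head_dim], illustrative
q_pe = torch.randn(4, 128, 64)

assert torch.equal(
    torch.cat([q_nope_out, q_pe], dim=2),
    torch.cat([q_nope_out, q_pe], dim=-1),
)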