"test/vscode:/vscode.git/clone" did not exist on "efaa27c19e0af0ea144f67b96d69acc8bb756623"
Unverified commit d58e3544, authored by Xiaoyu Zhang, committed by GitHub

simplify the control logic for using shared experts fusion (#5504)

parent bf86c5e9
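
Summary of the change: routed_scaling_factor is now read from the model config and threaded from DeepseekV2MoE through FusedMoE/EPMoE, the quantization apply methods, and select_experts down to the top-k routines, replacing the 2.5 that was hardcoded in grouped_topk and biased_grouped_topk_impl. At the same time, n_share_experts_fusion becomes a plain int whose default of 0 disables the fusion, so the separate --disable-shared-experts-fusion flag (and its HIP-specific forcing) is removed. The full diff follows; a small illustrative sketch of the weight logic is added after the top-k hunk below.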
@@ -136,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
@@ -164,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -215,6 +217,7 @@ class EPMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
......
@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )
@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
@@ -86,6 +89,7 @@ def moe_forward_native(
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )
     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
......
@@ -1547,6 +1547,7 @@ def fused_moe(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1601,6 +1602,7 @@ def fused_moe(
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )
     return fused_experts(
......
@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -147,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def forward_cuda(
@@ -165,6 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -176,6 +179,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         if _is_hip and get_bool_env_var("CK_MOE"):
@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
@@ -293,6 +298,7 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
         self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
@@ -637,6 +643,7 @@ class FusedMoE(torch.nn.Module):
             correction_bias=self.correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
         if self.reduce_results and self.tp_size > 1:
......
@@ -98,6 +98,7 @@ def grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -127,9 +128,7 @@ def grouped_topk(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

     if renormalize:
         topk_weights_sum = (
@@ -151,6 +150,7 @@ def biased_grouped_topk_impl(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -187,9 +187,7 @@ def biased_grouped_topk_impl(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

     if renormalize:
         topk_weights_sum = (
@@ -216,13 +214,16 @@ def biased_grouped_topk(
     topk_group: int = 0,
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
+    assert (
+        routed_scaling_factor is not None
+    ), "routed_scaling_factor is required for biased_grouped_topk"
     # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
     if (
         _is_cuda
         and gating_output.shape[1] // num_expert_group
         <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
-        and n_share_experts_fusion == 0
         and is_power_of_two(correction_bias.shape[0])
     ):
         return moe_fused_gate(
@@ -231,6 +232,8 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             topk,
+            n_share_experts_fusion,
+            routed_scaling_factor,
         )
     else:
         biased_grouped_topk_fn = (
@@ -249,6 +252,7 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
@@ -263,10 +267,9 @@ def select_experts(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    n_share_experts_fusion = 0
-    if global_server_args_dict["n_share_experts_fusion"] is not None:
-        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
@@ -280,6 +283,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -291,6 +295,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
......
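
For context on the top-k changes above: the fused shared-expert slot is the piece that actually consumes routed_scaling_factor. The following is a simplified, self-contained sketch of that selection, not the actual sglang implementation; it omits expert groups, the correction bias, and renormalization, and the function name is ours.

    import torch

    def topk_with_fused_shared_expert(
        router_logits: torch.Tensor,  # [num_tokens, num_routed_experts]
        topk: int,
        n_share_experts_fusion: int,
        routed_scaling_factor: float,
    ):
        # Routed weights come from a softmax over the router logits.
        scores = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
        if n_share_experts_fusion > 0:
            num_routed = router_logits.shape[1]
            # The fused shared-expert replicas sit after the routed experts;
            # pick one replica per token at random to spread the load.
            shared_ids = torch.randint(
                low=num_routed,
                high=num_routed + n_share_experts_fusion,
                size=(topk_ids.shape[0], 1),
                dtype=topk_ids.dtype,
                device=topk_ids.device,
            )
            # The extra slot's weight is sum(routed weights) / routed_scaling_factor,
            # which is what used to be hardcoded as "/ 2.5" for DeepSeek V3/R1.
            shared_weight = (
                topk_weights.sum(dim=-1, keepdim=True) / routed_scaling_factor
            )
            topk_ids = torch.cat([topk_ids, shared_ids], dim=-1)
            topk_weights = torch.cat([topk_weights, shared_weight], dim=-1)
        return topk_weights, topk_ids

    # Example: 4 tokens, 16 routed experts, top-8 routing, 2 fused replicas
    weights, ids = topk_with_fused_shared_expert(
        torch.randn(4, 16), topk=8, n_share_experts_fusion=2, routed_scaling_factor=2.5
    )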
@@ -290,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         assert activation == "silu"
         assert inplace and not no_combine
......
@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         # Expert fusion with INT8 quantization
......
@@ -283,6 +283,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         inplace: bool = True,
         no_combine: bool = False,
         apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -297,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         return fused_experts(
@@ -633,6 +635,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.topk import select_experts
@@ -653,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         return torch.ops.vllm.fused_marlin_moe(
......
@@ -892,6 +892,7 @@ class Fp8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -907,6 +908,7 @@ class Fp8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         if _is_hip:
......
@@ -347,6 +347,7 @@ class MoeWNA16Method:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -363,6 +364,7 @@ class MoeWNA16Method:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         weight_bits = self.quant_config.weight_bits
......
@@ -294,6 +294,7 @@ class W8A8FP8MoEMethod:
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -309,6 +310,7 @@ class W8A8FP8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         return fused_experts(
......
@@ -231,6 +231,7 @@ class W8A8Int8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -246,6 +247,7 @@ class W8A8Int8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
         return fused_experts(
......
@@ -81,7 +81,6 @@ global_server_args_dict = {
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
......
@@ -163,7 +163,6 @@ class ModelRunner:
             "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
             "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
             "n_share_experts_fusion": server_args.n_share_experts_fusion,
-            "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
             "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
             "use_mla_backend": self.use_mla_backend,
         }
......
@@ -189,11 +189,7 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = (
-            global_server_args_dict["n_share_experts_fusion"]
-            if global_server_args_dict["n_share_experts_fusion"] is not None
-            else 0
-        )
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]

         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -226,6 +222,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -334,6 +331,7 @@ class DeepseekV2MoE(nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def _forward_shared_experts(self, hidden_states):
-        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+        if self.n_share_experts_fusion == 0:
             return self.shared_experts(hidden_states)
         else:
             return None
@@ -1346,24 +1344,21 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        if (
-            global_server_args_dict.get("disable_shared_experts_fusion", False)
-            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
-            or self.config.n_routed_experts != 256
-            or self.config.routed_scaling_factor != 2.5
-        ):
-            self.n_share_experts_fusion = None
-            global_server_args_dict["n_share_experts_fusion"] = None
-            logger.info(
-                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
-            )
-        elif self.n_share_experts_fusion is None:
-            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-            self.n_share_experts_fusion = self.tp_size
-            logger.info(
-                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
-            )
+        if self.n_share_experts_fusion > 0:
+            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+            if (
+                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                or self.config.n_routed_experts != 256
+            ):
+                self.n_share_experts_fusion = 0
+                global_server_args_dict["n_share_experts_fusion"] = 0
+                logger.info(
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                )
+            else:
+                assert (
+                    self.n_share_experts_fusion == self.tp_size
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."

         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -1484,7 +1479,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+        if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
             if self.quant_config.get_name() == "w8a8_int8":
@@ -1543,12 +1538,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
-            + (
-                self.n_share_experts_fusion
-                if self.n_share_experts_fusion is not None
-                else 0
-            ),
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
         params_dict = dict(self.named_parameters())
......
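
The net effect of the DeepseekV2ForCausalLM change is a single decision point: fusion is whatever the user set via --n-share-experts-fusion (0 disables it), it is honored only for DeepSeek V3/R1 checkpoints (256 routed experts), and when enabled it must equal the TP size. A standalone sketch of that validation follows; the function and argument names are ours, not the actual model code.

    import logging

    logger = logging.getLogger(__name__)

    def resolve_n_share_experts_fusion(
        n_share_experts_fusion: int,  # value of --n-share-experts-fusion (0 = disabled)
        architecture: str,            # e.g. config.architectures[0]
        n_routed_experts: int,        # e.g. config.n_routed_experts
        tp_size: int,
    ) -> int:
        # Mirror of the simplified control logic: there is no separate disable
        # flag anymore, 0 means off, and the optimization is restricted to
        # DeepSeek V3/R1.
        if n_share_experts_fusion == 0:
            return 0
        if architecture != "DeepseekV3ForCausalLM" or n_routed_experts != 256:
            logger.info(
                "Only Deepseek V3/R1 can use shared experts fusion optimization. "
                "Shared experts fusion optimization is disabled."
            )
            return 0
        assert n_share_experts_fusion == tp_size, (
            f"Shared experts fusion is enabled; set --n-share-experts-fusion to "
            f"{tp_size} (the TP size) for best performance."
        )
        return n_share_experts_fusion

In practice this means the default (0) leaves fusion off, and passing --n-share-experts-fusion equal to the TP size enables it for DeepSeek V3/R1; the removed --disable-shared-experts-fusion flag is no longer needed.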
@@ -183,7 +183,6 @@ class ServerArgs:
     warmups: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None
     n_share_experts_fusion: int = 0
-    disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
@@ -229,9 +228,6 @@ class ServerArgs:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None

-        if is_hip():
-            self.disable_shared_experts_fusion = True
-
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -1126,13 +1122,8 @@ class ServerArgs:
             "--n-share-experts-fusion",
             type=int,
             default=0,
-            help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
-            "we use tp_size by default.",
-        )
-        parser.add_argument(
-            "--disable-shared-experts-fusion",
-            action="store_true",
-            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+            help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
+            "set it to tp_size can get best optimized performace.",
         )
         parser.add_argument(
             "--disable-chunked-prefix-cache",
......