Unverified Commit d58e3544 authored by Xiaoyu Zhang, committed by GitHub

simplify the control logic for using shared experts fusion (#5504)

parent bf86c5e9
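
This change threads the model's `routed_scaling_factor` from the config down through the MoE layers into the top-k routing code, replacing the hardcoded `2.5`, and makes `n_share_experts_fusion` a plain integer (0 disables fusion) so the `None` checks and the separate `--disable-shared-experts-fusion` flag can be dropped. A minimal sketch of the resulting shared-expert weight computation, assuming a `[num_tokens, top_k + 1]` weight tensor whose last column is the fused shared expert (the helper name is illustrative, not part of the diff):

```python
import torch

def fuse_shared_expert_weight(topk_weights: torch.Tensor,
                              routed_scaling_factor: float) -> torch.Tensor:
    """Sketch only: derive the fused shared-expert weight from the routed weights."""
    out = topk_weights.clone()
    # The last column is the appended shared expert; its weight is the sum of the
    # routed weights divided by the model's routed_scaling_factor (previously a
    # hardcoded 2.5 for DeepSeek V3/R1).
    out[:, -1] = out[:, :-1].sum(dim=-1) / routed_scaling_factor
    return out
```
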
......@@ -136,6 +136,7 @@ class EPMoE(torch.nn.Module):
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
):
super().__init__()
......@@ -164,6 +165,7 @@ class EPMoE(torch.nn.Module):
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
......@@ -215,6 +217,7 @@ class EPMoE(torch.nn.Module):
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
......
......@@ -26,6 +26,7 @@ def fused_moe_forward_native(
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
if apply_router_weight_on_input:
......@@ -41,6 +42,7 @@ def fused_moe_forward_native(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
torch_native=True,
)
......@@ -71,6 +73,7 @@ def moe_forward_native(
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
......@@ -86,6 +89,7 @@ def moe_forward_native(
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
# Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
......
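The native path above follows the referenced DeepSeek modeling code: the selected experts are applied per token and their outputs are summed with the routing weights. A hedged sketch of that combine step (tensor names and shapes are illustrative, not taken from the diff):

```python
import torch

def combine_expert_outputs(expert_outputs: torch.Tensor,  # [num_tokens, top_k, hidden]
                           topk_weights: torch.Tensor      # [num_tokens, top_k]
                           ) -> torch.Tensor:
    # Weighted sum over the selected experts, as in the referenced modeling_deepseek code.
    return (expert_outputs * topk_weights.unsqueeze(-1)).sum(dim=1)
```
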
......@@ -1547,6 +1547,7 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1601,6 +1602,7 @@ def fused_moe(
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
......
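For reference, a sketch of calling the Triton `fused_moe` entry point with the new keyword. The positional signature (`hidden_states, w1, w2, gating_output, topk, renormalize`), the weight layouts, and the shapes are assumptions based on the surrounding code, not part of this diff; the call also needs a CUDA device:

```python
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, H, I, T, K = 8, 128, 256, 4, 2  # experts, hidden, intermediate, tokens, top-k (illustrative)
x = torch.randn(T, H, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.float16, device="cuda")  # fused gate/up projections
w2 = torch.randn(E, H, I, dtype=torch.float16, device="cuda")
gating_output = torch.randn(T, E, dtype=torch.float16, device="cuda")

out = fused_moe(x, w1, w2, gating_output, topk=K, renormalize=True,
                routed_scaling_factor=None)  # None keeps the previous behavior
```
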
......@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.forward(
x=x,
......@@ -147,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_cuda(
......@@ -165,6 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
......@@ -176,6 +179,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if _is_hip and get_bool_env_var("CK_MOE"):
......@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
super().__init__()
......@@ -293,6 +298,7 @@ class FusedMoE(torch.nn.Module):
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
......@@ -637,6 +643,7 @@ class FusedMoE(torch.nn.Module):
correction_bias=self.correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
)
if self.reduce_results and self.tp_size > 1:
......
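The same pattern repeats through `EPMoE`, `FusedMoE`, and every quantized `apply` method below: the layer stores `routed_scaling_factor` at construction and forwards it, together with the other routing arguments, to `select_experts`. A minimal sketch of that pattern, with a hypothetical class standing in for the real layers:

```python
from typing import Callable, Optional

class MoEScalingPassthrough:
    """Hypothetical stand-in: not a class in the diff, just the threading pattern."""

    def __init__(self, routed_scaling_factor: Optional[float] = None):
        self.routed_scaling_factor = routed_scaling_factor

    def route(self, hidden_states, router_logits, select_experts: Callable):
        # The layer never interprets the factor itself; it only forwards it to the
        # top-k selection, which uses it when a fused shared expert is appended.
        return select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            routed_scaling_factor=self.routed_scaling_factor,
        )
```
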
......@@ -98,6 +98,7 @@ def grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
routed_scaling_factor: Optional[float] = None,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
......@@ -127,9 +128,7 @@ def grouped_topk(
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
if renormalize:
topk_weights_sum = (
......@@ -151,6 +150,7 @@ def biased_grouped_topk_impl(
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
routed_scaling_factor: Optional[float] = None,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
......@@ -187,9 +187,7 @@ def biased_grouped_topk_impl(
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
if renormalize:
topk_weights_sum = (
......@@ -216,13 +214,16 @@ def biased_grouped_topk(
topk_group: int = 0,
compiled: bool = True,
n_share_experts_fusion: int = 0,
routed_scaling_factor: Optional[float] = None,
):
assert (
routed_scaling_factor is not None
), "routed_scaling_factor is required for biased_grouped_topk"
    # TODO: the moe_fused_gate kernel does not support n_share_experts_fusion > 0 yet.
if (
_is_cuda
and gating_output.shape[1] // num_expert_group
<= 32  # The moe_fused_gate kernel ensures that num_experts/num_expert_group does not exceed MAX_VPT=32 for now. When the kernel can handle MAX_VPT > 32, we can remove this assertion.
and n_share_experts_fusion == 0
and is_power_of_two(correction_bias.shape[0])
):
return moe_fused_gate(
......@@ -231,6 +232,8 @@ def biased_grouped_topk(
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
)
else:
biased_grouped_topk_fn = (
......@@ -249,6 +252,7 @@ def biased_grouped_topk(
num_expert_group,
topk_group,
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -263,10 +267,9 @@ def select_experts(
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
routed_scaling_factor: Optional[float] = None,
):
n_share_experts_fusion = 0
if global_server_args_dict["n_share_experts_fusion"] is not None:
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
    # DeepSeek V2/V3/R1 series models use grouped_topk
if use_grouped_topk:
assert topk_group is not None
......@@ -280,6 +283,7 @@ def select_experts(
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=routed_scaling_factor,
)
else:
topk_weights, topk_ids = biased_grouped_topk(
......@@ -291,6 +295,7 @@ def select_experts(
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=routed_scaling_factor,
)
elif torch_native and custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(
......
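With `n_share_experts_fusion` always an integer, `select_experts` can read it straight from `global_server_args_dict` without a `None` check, and `biased_grouped_topk` now asserts that a `routed_scaling_factor` is supplied. A small self-contained worked example of the weight layout this produces when one fused shared expert is appended (the numbers are illustrative):

```python
import torch

routed_scaling_factor = 2.5  # DeepSeek V3/R1 config value
routed_weights = torch.tensor([[0.4, 0.3, 0.2, 0.1]])  # top-k routed weights for one token

# One extra column for the fused shared expert; its weight is the routed sum divided
# by routed_scaling_factor, matching the updated grouped_topk code above.
topk_weights = torch.cat(
    [routed_weights, routed_weights.sum(dim=-1, keepdim=True) / routed_scaling_factor],
    dim=-1,
)
print(topk_weights)  # tensor([[0.4000, 0.3000, 0.2000, 0.1000, 0.4000]])
```
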
......@@ -290,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
assert activation == "silu"
assert inplace and not no_combine
......
......@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
......@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
# Expert fusion with INT8 quantization
......
......@@ -283,6 +283,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
inplace: bool = True,
no_combine: bool = False,
apply_router_weight_on_input: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
from sglang.srt.layers.moe.topk import select_experts
......@@ -297,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
......@@ -633,6 +635,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.topk import select_experts
......@@ -653,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return torch.ops.vllm.fused_marlin_moe(
......
......@@ -892,6 +892,7 @@ class Fp8MoEMethod:
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
......@@ -907,6 +908,7 @@ class Fp8MoEMethod:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if _is_hip:
......
......@@ -347,6 +347,7 @@ class MoeWNA16Method:
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -363,6 +364,7 @@ class MoeWNA16Method:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
weight_bits = self.quant_config.weight_bits
......
......@@ -294,6 +294,7 @@ class W8A8FP8MoEMethod:
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
......@@ -309,6 +310,7 @@ class W8A8FP8MoEMethod:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
......
......@@ -231,6 +231,7 @@ class W8A8Int8MoEMethod:
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
......@@ -246,6 +247,7 @@ class W8A8Int8MoEMethod:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
......
......@@ -81,7 +81,6 @@ global_server_args_dict = {
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
}
......
......@@ -163,7 +163,6 @@ class ModelRunner:
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"use_mla_backend": self.use_mla_backend,
}
......
......@@ -189,11 +189,7 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.n_share_experts_fusion = (
global_server_args_dict["n_share_experts_fusion"]
if global_server_args_dict["n_share_experts_fusion"] is not None
else 0
)
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
if self.tp_size > config.n_routed_experts:
raise ValueError(
......@@ -226,6 +222,7 @@ class DeepseekV2MoE(nn.Module):
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
......@@ -334,6 +331,7 @@ class DeepseekV2MoE(nn.Module):
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
......@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
return final_hidden_states
def _forward_shared_experts(self, hidden_states):
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
if self.n_share_experts_fusion == 0:
return self.shared_experts(hidden_states)
else:
return None
......@@ -1346,24 +1344,21 @@ class DeepseekV2ForCausalLM(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
global_server_args_dict.get("disable_shared_experts_fusion", False)
or self.config.architectures[0] != "DeepseekV3ForCausalLM"
or self.config.n_routed_experts != 256
or self.config.routed_scaling_factor != 2.5
):
self.n_share_experts_fusion = None
global_server_args_dict["n_share_experts_fusion"] = None
logger.info(
"Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
)
elif self.n_share_experts_fusion is None:
global_server_args_dict["n_share_experts_fusion"] = self.tp_size
self.n_share_experts_fusion = self.tp_size
logger.info(
f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
)
if self.n_share_experts_fusion > 0:
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
self.config.architectures[0] != "DeepseekV3ForCausalLM"
or self.config.n_routed_experts != 256
):
self.n_share_experts_fusion = 0
global_server_args_dict["n_share_experts_fusion"] = 0
logger.info(
"Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
)
else:
assert (
self.n_share_experts_fusion == self.tp_size
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
......@@ -1484,7 +1479,7 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
if self.n_share_experts_fusion > 0:
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config.get_name() == "w8a8_int8":
......@@ -1543,12 +1538,7 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts
+ (
self.n_share_experts_fusion
if self.n_share_experts_fusion is not None
else 0
),
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
)
params_dict = dict(self.named_parameters())
......
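The simplified control logic above can be summarized as a small helper; this is a hedged restatement of the diff, not an additional function in the codebase:

```python
def resolve_n_share_experts_fusion(n_share_experts_fusion: int,
                                   architecture: str,
                                   n_routed_experts: int,
                                   tp_size: int) -> int:
    """Return the effective fusion count; 0 disables shared experts fusion."""
    if n_share_experts_fusion > 0:
        # Only DeepSeek V3/R1 (256 routed experts) can use the fusion path today.
        if architecture != "DeepseekV3ForCausalLM" or n_routed_experts != 256:
            return 0
        assert n_share_experts_fusion == tp_size, (
            "set --n-share-experts-fusion to the TP size for the best performance"
        )
    return n_share_experts_fusion
```

Because the value is never `None`, the weight loader can simply size the expert mapping as `config.n_routed_experts + self.n_share_experts_fusion`, which is what the change above does.
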
......@@ -183,7 +183,6 @@ class ServerArgs:
warmups: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
n_share_experts_fusion: int = 0
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
......@@ -229,9 +228,6 @@ class ServerArgs:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
if is_hip():
self.disable_shared_experts_fusion = True
# Set mem fraction static, which depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
......@@ -1126,13 +1122,8 @@ class ServerArgs:
"--n-share-experts-fusion",
type=int,
default=0,
help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
"we use tp_size by default.",
)
parser.add_argument(
"--disable-shared-experts-fusion",
action="store_true",
help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
"set it to tp_size can get best optimized performace.",
)
parser.add_argument(
"--disable-chunked-prefix-cache",
......
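With `--disable-shared-experts-fusion` removed, fusion is controlled by the single integer flag. A hedged sketch of the corresponding `ServerArgs` usage (the field names follow the diff; the other constructor details are assumptions):

```python
from sglang.srt.server_args import ServerArgs

# 0 (the default) disables shared experts fusion; setting it to the TP size enables it.
args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V3",
    tp_size=8,
    n_share_experts_fusion=8,
)
```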