Unverified commit 2df9d40a, authored by fzyzcjy and committed by GitHub

Minor code cleanup refactor for DeepSeek models (#6324)

parent 8dc191f2
@@ -5,6 +5,7 @@
 import torch
 from torch.nn import Module

 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.schedule_batch import global_server_args_dict

 try:
     from deep_gemm import (
@@ -40,7 +41,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -1173,3 +1174,11 @@ class DeepEPMoE(EPMoE):
         )
         return down_output
+
+
+def get_moe_impl_class():
+    if global_server_args_dict["enable_deepep_moe"]:
+        return DeepEPMoE
+    if global_server_args_dict["enable_ep_moe"]:
+        return EPMoE
+    return FusedMoE
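
The new get_moe_impl_class helper centralizes the choice of MoE implementation that was previously repeated as an inline conditional at each call site. Below is a minimal, self-contained sketch of the same selection pattern; the flag dict and the three empty classes are stand-ins for global_server_args_dict and the real FusedMoE / EPMoE / DeepEPMoE, not the actual SGLang objects.

# Illustrative stand-ins only; flag names mirror global_server_args_dict.
server_args = {"enable_deepep_moe": False, "enable_ep_moe": True}


class FusedMoE: ...


class EPMoE: ...


class DeepEPMoE: ...


def get_moe_impl_class():
    # Priority matches the diff above: DeepEP > EP > fused Triton MoE.
    if server_args["enable_deepep_moe"]:
        return DeepEPMoE
    if server_args["enable_ep_moe"]:
        return EPMoE
    return FusedMoE


# Callers instantiate whatever class is returned:
experts_cls = get_moe_impl_class()
print(experts_cls.__name__)  # -> "EPMoE" for the flags above
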
@@ -52,7 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
@@ -222,13 +222,7 @@ class DeepseekV2MoE(nn.Module):
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        self.experts = MoEImpl(
+        self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts + self.n_share_experts_fusion,
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
@@ -251,26 +245,19 @@ class DeepseekV2MoE(nn.Module):
         if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-            if not global_server_args_dict["enable_deepep_moe"]:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                )
-            else:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                    tp_rank=0,
-                    tp_size=1,
-                )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )

         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
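
The refactored shared-experts construction collapses the previous if/else duplication into a single call by conditionally unpacking extra keyword arguments. A small sketch of that pattern follows; make_mlp is a hypothetical stand-in for DeepseekV2MLP, not the real constructor.

# Sketch of the conditional-kwargs pattern used above.
def make_mlp(hidden_size, reduce_results=False, tp_rank=None, tp_size=None):
    return dict(
        hidden_size=hidden_size,
        reduce_results=reduce_results,
        tp_rank=tp_rank,
        tp_size=tp_size,
    )


enable_deepep_moe = True

# When the flag is set, tp_rank/tp_size pin the module to a single
# tensor-parallel rank; otherwise the empty dict unpacks to nothing,
# so both branches share one constructor call.
mlp_kwargs = make_mlp(
    hidden_size=4096,
    **(dict(tp_rank=0, tp_size=1) if enable_deepep_moe else {}),
)
print(mlp_kwargs)
# {'hidden_size': 4096, 'reduce_results': False, 'tp_rank': 0, 'tp_size': 1}
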
@@ -1726,12 +1713,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
......
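
As the comment in the last hunk notes, make_expert_params_mapping yields (param_name, weight_name, expert_id, shard_id) tuples. The sketch below is a rough, hypothetical illustration of how such a mapping is typically consumed when loading expert weights from a checkpoint; it is not the actual load_weights code in this file, and the function name and loop structure are assumptions.

# Hypothetical consumer of expert_params_mapping; names are illustrative only.
def load_expert_weight(params_dict, expert_params_mapping, ckpt_name, loaded_weight):
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in ckpt_name:
            continue
        # Map the per-expert checkpoint tensor onto its fused model parameter.
        target_name = ckpt_name.replace(weight_name, param_name)
        param = params_dict[target_name]
        # The parameter's weight_loader is expected to scatter the tensor into
        # the slice belonging to this expert and shard.
        param.weight_loader(
            param, loaded_weight, target_name, shard_id=shard_id, expert_id=expert_id
        )
        return True
    return False
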