Unverified commit 13feffd0, authored by fzyzcjy and committed by GitHub

Fix master CI for DeepSeek (#6447)

parent e98afbe0
@@ -141,6 +141,7 @@ class EPMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -164,6 +165,7 @@ class EPMoE(torch.nn.Module):
         )
         self.tp_rank = get_tensor_model_parallel_rank()
+        self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
         self.num_experts_per_partition = self.num_experts // self.tp_size
@@ -837,6 +839,7 @@ class DeepEPMoE(EPMoE):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -856,6 +859,7 @@ class DeepEPMoE(EPMoE):
             top_k,
             hidden_size,
             intermediate_size,
+            layer_id,
             params_dtype,
             renormalize,
             use_grouped_topk,
...
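Taken together, the hunks above thread a layer index through the expert-parallel MoE constructors: EPMoE gains a required layer_id and stores it, and DeepEPMoE forwards it positionally to its base class. A minimal sketch of the pattern, with the classes reduced to the parameters this diff touches (the real signatures take many more arguments):

from typing import Optional

import torch

class EPMoE(torch.nn.Module):
    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,  # new: required, inserted before the optional args
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.layer_id = layer_id  # stored for per-layer expert dispatch
        self.num_experts = num_experts

class DeepEPMoE(EPMoE):
    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        params_dtype: Optional[torch.dtype] = None,
    ):
        # Forwarded positionally, so layer_id must occupy the same slot in
        # both signatures (between intermediate_size and params_dtype).
        super().__init__(
            num_experts, top_k, hidden_size, intermediate_size, layer_id, params_dtype
        )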
@@ -283,6 +283,7 @@ class FusedMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: Optional[int] = None,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
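Note the asymmetry: FusedMoE takes layer_id with a None default while EPMoE requires it. A minimal sketch of why the default matters, assuming (as the hunks suggest) that only the expert-parallel path actually consumes the index:

from typing import Optional

import torch

class FusedMoE(torch.nn.Module):
    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: Optional[int] = None,  # defaulted: old call sites still work
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.layer_id = layer_id

# Callers that predate the change need no update:
moe = FusedMoE(num_experts=8, top_k=2, hidden_size=64, intermediate_size=128)
assert moe.layer_id is None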
@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
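This import cleanup means the model file now names only the factory, not the concrete MoE classes. A hedged sketch of the factory pattern in play; the selection flags shown here are assumptions about what the real get_moe_impl_class consults, not its verified body:

class FusedMoE: ...
class EPMoE: ...
class DeepEPMoE(EPMoE): ...

# Illustrative stand-in for sglang's global server-args mapping.
global_server_args_dict = {"enable_deepep_moe": False, "enable_ep_moe": False}

def get_moe_impl_class():
    # Pick the MoE implementation from global config so model code never
    # refers to a concrete class (which is why the direct imports could go).
    if global_server_args_dict["enable_deepep_moe"]:
        return DeepEPMoE
    if global_server_args_dict["enable_ep_moe"]:
        return EPMoE
    return FusedMoE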
@@ -114,7 +114,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
-
 logger = logging.getLogger(__name__)
@@ -216,6 +215,7 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -224,6 +224,7 @@ class DeepseekV2MoE(nn.Module):
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.layer_id = layer_id
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -244,6 +245,7 @@ class DeepseekV2MoE(nn.Module):
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -344,6 +346,9 @@ class DeepseekV2MoE(nn.Module):
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
             )
         else:
             state.topk_idx_local = torch.full(
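The expert selection now receives an ExpertLocationDispatchInfo built from the layer index stored above. The class belongs to sglang's expert-location machinery; the following is only a hedged sketch of the shape this call site implies, with the fields and the disabled-path behavior as assumptions rather than the real implementation:

from dataclasses import dataclass
from typing import Optional

def _expert_location_dispatch_enabled() -> bool:
    # Placeholder: the real check would consult server args / global config.
    return False

@dataclass
class ExpertLocationDispatchInfo:
    # Assumed reduced form: carries whatever the dispatcher needs to map
    # this layer's logical expert ids to physical ones.
    layer_id: int

    @classmethod
    def init_new(cls, layer_id: int) -> Optional["ExpertLocationDispatchInfo"]:
        # Returning None when the feature is off lets callers pass the result
        # into the selection path unconditionally, as the hunk above does.
        if not _expert_location_dispatch_enabled():
            return None
        return cls(layer_id=layer_id)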
@@ -1183,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                layer_id=self.layer_id,
             )
         else:
             if enable_moe_dense_fully_dp():
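Because DeepseekV2MoE's new layer_id parameter is required (it sits before the optional quant_config), every construction site has to be touched; the decoder layer passes its own index through. A sketch of the updated call and of the failure mode a missed caller would hit:

# Hypothetical reduced constructor mirroring the hunks above.
class DeepseekV2MoE:
    def __init__(self, config, layer_id: int, quant_config=None, prefix: str = ""):
        self.layer_id = layer_id

mlp = DeepseekV2MoE(config=None, layer_id=3, prefix="model.layers.3.mlp")

# A construction site that was not updated fails at model build time:
#   DeepseekV2MoE(config=None)
#   TypeError: __init__() missing 1 required positional argument: 'layer_id'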
@@ -1246,9 +1252,7 @@
         zero_allocator: BumpAllocator,
     ):
         state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
-            self.layer_communicator.prepare_attn(
-                hidden_states, residual, state.forward_batch
-            )
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
         )
         state.update(
             dict(
...
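The final hunk is the behavioral fix: the old code read forward_batch off the state object inside the prepare_attn call, but the function's own forward_batch parameter is the value actually in scope there (state is only populated by the state.update(...) that follows). A sketch of the corrected shape, with the enclosing function name assumed for illustration from the visible signature tail:

def op_comm_prepare_attn(  # name assumed; parameters match the visible hunk
    self,
    state,
    hidden_states,
    residual,
    forward_batch,
    zero_allocator,
):
    # Fixed: use the in-scope parameter, not state.forward_batch, which has
    # not been stored on state yet at this point in the function.
    state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
        self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
    )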