Unverified Commit e3c4bd31 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix DeepSeek error when using DeepEP mode (#5190)

parent 5db37c86
......@@ -280,10 +280,7 @@ class DeepseekV2MoE(nn.Module):
return self.forward_deepep(hidden_states, forward_mode)
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = (
......@@ -313,8 +310,7 @@ class DeepseekV2MoE(nn.Module):
):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
shared_output = self._forward_shared_experts(hidden_states)
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
......@@ -364,6 +360,12 @@ class DeepseekV2MoE(nn.Module):
return final_hidden_states
def _forward_shared_experts(self, hidden_states):
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
return self.shared_experts(hidden_states)
else:
return None
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
import math
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment