Commit 342c58bc authored by Kunshang Ji, committed by GitHub

[BugFix]fix Qwen3 MoE call gate twice (#40664)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent fe9c3d6c
@@ -231,11 +231,19 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if self.experts.is_internal_router:
+            # In this case, the gate/router runs inside the FusedMoE class
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # Actually this will be dead code, since we always pass gate into
+            # FusedMoE in the current implementation. But we keep this code
+            # here for clarity and future flexibility.
+            router_logits, _ = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
...
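
The double gate call that this change removes can be illustrated with a small pure-Python sketch. It is not the vLLM implementation: `CountingGate`, `ToyExperts`, `forward_before`, and `forward_after` are hypothetical names invented here, and the toy experts layer is assumed to ignore the `router_logits` argument and run its own gate whenever it owns one. Under those assumptions, the old call pattern invokes the gate twice per forward pass and the new one invokes it once, with identical output.

```python
class CountingGate:
    """Hypothetical stand-in for the gate/router; counts how often it runs."""

    def __init__(self):
        self.calls = 0

    def __call__(self, hidden_states):
        self.calls += 1
        # Toy "logits": one score per token. The second element mirrors the
        # two-value return implied by `router_logits, _ = self.gate(...)`.
        return [sum(token) for token in hidden_states], None


class ToyExperts:
    """Hypothetical stand-in for a fused MoE layer that may own the gate."""

    def __init__(self, gate=None):
        self.gate = gate

    @property
    def is_internal_router(self):
        return self.gate is not None

    def __call__(self, hidden_states, router_logits):
        if self.is_internal_router:
            # Assumption: with an internal router, the passed argument is
            # ignored and the gate is run here on the hidden states.
            router_logits, _ = self.gate(hidden_states)
        # Toy "expert": scale each token by its routing score.
        return [[x * w for x in token]
                for token, w in zip(hidden_states, router_logits)]


def forward_before(gate, experts, hidden_states):
    # Old pattern: always run the gate in the block, then call experts.
    router_logits, _ = gate(hidden_states)
    return experts(hidden_states=hidden_states, router_logits=router_logits)


def forward_after(gate, experts, hidden_states):
    # New pattern: skip the outer gate call when the router is internal.
    if experts.is_internal_router:
        return experts(hidden_states=hidden_states, router_logits=hidden_states)
    router_logits, _ = gate(hidden_states)
    return experts(hidden_states=hidden_states, router_logits=router_logits)


if __name__ == "__main__":
    tokens = [[1.0, 2.0], [3.0, 4.0]]
    gate = CountingGate()
    experts = ToyExperts(gate=gate)

    forward_before(gate, experts, tokens)
    print("gate calls, old pattern:", gate.calls)  # 2

    gate.calls = 0
    forward_after(gate, experts, tokens)
    print("gate calls, new pattern:", gate.calls)  # 1
```

In this sketch the redundant call only wastes compute, since both patterns return the same result; whether the same holds in the real layer depends on details not shown in the diff.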