Unverified commit 5aff1e93 authored by fzyzcjy, committed by GitHub

Fix Qwen3MoE missing token padding optimization (#6820)

parent 8e3797be
@@ -66,6 +66,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -91,6 +92,7 @@ def fused_topk(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
@@ -363,15 +365,13 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
-        assert (
-            num_token_non_padded is None
-        ), "num_token_non_padded is not yet supported in fused_topk"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
     else:
...
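
The diff above calls _mask_topk_ids_padded_region but does not show its body. The sketch below is only an illustration of the behavior implied by the call site, assuming padded rows are overwritten with a sentinel expert id of -1; the in-tree helper may be implemented differently (for example as a fused kernel).

from typing import Optional

import torch


def _mask_topk_ids_padded_region(
    topk_ids: torch.Tensor,  # [num_tokens, topk], integer expert ids
    num_token_non_padded: Optional[torch.Tensor],  # scalar tensor, or None if the batch is not padded
) -> None:
    # No-op when the batch carries no padding information.
    if num_token_non_padded is None:
        return
    # Rows at or beyond num_token_non_padded are padding (e.g. rows added to
    # reach a fixed CUDA-graph batch size). Overwrite their expert ids in place
    # with a sentinel so downstream MoE kernels spend no expert work on them.
    row_idx = torch.arange(topk_ids.shape[0], device=topk_ids.device)
    topk_ids[row_idx >= num_token_non_padded, :] = -1
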
@@ -193,6 +193,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=self.renormalize,
+            num_token_non_padded=forward_batch.num_token_non_padded,
             expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                 layer_id=self.layer_id,
             ),
@@ -260,6 +261,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=self.renormalize,
+            num_token_non_padded=state.forward_batch.num_token_non_padded,
             expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                 layer_id=self.layer_id,
             ),
...
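
Before this commit, Qwen3MoeSparseMoeBlock did not pass num_token_non_padded and fused_topk could not accept it (hence the assert removed in select_experts), so padded rows went through full expert computation; the commit threads the value from forward_batch down to fused_topk, where the new masking call drops them. The snippet below is a small, self-contained illustration of that pattern; every name and shape in it is invented for the example, only the flow mirrors the diff: routing runs on the full padded batch, then rows past num_token_non_padded are invalidated so expert computation can skip them.

import torch

num_experts, top_k = 8, 2
num_real_tokens, padded_batch_size = 3, 8  # batch padded, e.g. to a CUDA-graph size

# Router output for every row, including the padding rows.
router_logits = torch.randn(padded_batch_size, num_experts)
scores = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize

# The equivalent of num_token_non_padded: how many leading rows are real tokens.
num_token_non_padded = torch.tensor(num_real_tokens)
row_idx = torch.arange(padded_batch_size)
topk_ids[row_idx >= num_token_non_padded, :] = -1  # padded rows route to no expert

print(topk_ids)  # the last 5 rows are all -1, so expert kernels can ignore them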