# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
topk_weights,topk_ids=ops.moe_fused_gate(
router_logits,
e_score_correction_bias,
num_expert_group,
topk_group,
top_k,
routed_scaling_factor=routed_scaling_factor,
n_share_experts_fusion=0,
)
else:
topk_weights,topk_ids=grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
...
...
@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module):