         <= 32  # the moe_fused_gate kernel currently requires that num_experts / num_expert_group not exceed MAX_VPT=32; once the kernel can handle MAX_VPT > 32, this assertion can be removed
         and is_power_of_two(correction_bias.shape[0])
     ):
-        return moe_fused_gate(
+        topk_weights, topk_ids = moe_fused_gate(
             gating_output,
             correction_bias,
             num_expert_group,
...
...
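For context, the guard in the hunk above only takes the fused-kernel fast path when the per-group expert count fits within MAX_VPT and the expert count is a power of two. Below is a minimal sketch of that dispatch check; the `is_power_of_two` body and the `can_use_fused_gate` wrapper are illustrative assumptions, not part of this diff:

```python
MAX_VPT = 32  # current per-thread limit of the moe_fused_gate kernel


def is_power_of_two(n: int) -> bool:
    # Assumed helper body; the diff only shows its call site.
    return n > 0 and (n & (n - 1)) == 0


def can_use_fused_gate(num_experts: int, num_expert_group: int) -> bool:
    # Hypothetical wrapper mirroring the guard: experts per group must fit
    # in MAX_VPT, and the expert count (correction_bias.shape[0] in the
    # diff) must be a power of two.
    return (
        num_experts // num_expert_group <= MAX_VPT
        and is_power_of_two(num_experts)
    )


assert can_use_fused_gate(num_experts=256, num_expert_group=8)      # 32 per group: ok
assert not can_use_fused_gate(num_experts=288, num_expert_group=8)  # 36 > MAX_VPT
```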
@@ -235,6 +251,11 @@ def biased_grouped_topk(
             n_share_experts_fusion,
             routed_scaling_factor,
         )
+        # TODO: fuse this into the kernel; for now, use the slow manual operation
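The hunk ends before the manual operation itself, so the exact step is not visible here. As a rough sketch of the pattern the TODO describes (an extra eager PyTorch pass over the gate outputs, kept outside the kernel until it can be fused in); the renormalization step and the helper name below are illustrative assumptions, not taken from the diff:

```python
import torch


def _manual_postprocess(topk_weights: torch.Tensor, topk_ids: torch.Tensor):
    # Hypothetical stand-in for the "slow manual operation": an extra
    # element-wise pass over the gate outputs in plain PyTorch, kept here
    # until it can be fused into the moe_fused_gate kernel itself.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Usage: shapes are (num_tokens, topk), matching the fused gate's outputs.
weights = torch.rand(4, 8)
ids = torch.randint(0, 256, (4, 8))
weights, ids = _manual_postprocess(weights, ids)
```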