Commit 86b5aefe authored by zhuwenwen's avatar zhuwenwen
Browse files

fix nn_moe args

parent 1072b724
......@@ -271,11 +271,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
x=x,
router_logits=router_logits,
use_nn_moe=use_nn_moe,
)
def get_fused_moe_quant_config(
......@@ -294,6 +296,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
......@@ -334,6 +337,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
use_nn_moe=use_nn_moe,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment