Commit 86b5aefe authored by zhuwenwen's avatar zhuwenwen
Browse files

fix nn_moe args

parent 1072b724
...@@ -271,11 +271,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -271,11 +271,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward( return self.forward(
layer=layer, layer=layer,
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
use_nn_moe=use_nn_moe,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -294,6 +296,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -294,6 +296,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
...@@ -334,6 +337,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -334,6 +337,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
use_nn_moe=use_nn_moe,
) )
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment