Unverified Commit 0812d8dd authored by liuzhenwei's avatar liuzhenwei Committed by GitHub
Browse files

[Hardware][Gaudi][BugFix] fix arguments of hpu fused moe (#15945)


Signed-off-by: default avatarzhenwei <zhenweiliu@habana.ai>
parent bf7e3c51
...@@ -254,9 +254,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -254,9 +254,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
...@@ -472,7 +475,7 @@ class FusedMoE(torch.nn.Module): ...@@ -472,7 +475,7 @@ class FusedMoE(torch.nn.Module):
"non-grouped topk.") "non-grouped topk.")
if current_platform.is_hpu(): if current_platform.is_hpu():
from vllm_hpu_extension.ops import DynamicFusedMOE from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment