Unverified commit 87f1b8ca, authored by Xinyu Chen, committed by GitHub
Browse files

CustomOp: Unify aiter impl into GroupedTopk (#31221)


Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
parent 887e900b
......@@ -35,6 +35,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
......@@ -1295,6 +1298,7 @@ class GroupedTopk(CustomOp):
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
num_fused_shared_experts: int = 0,
) -> None:
super().__init__()
self.native_impl = grouped_topk
......@@ -1304,6 +1308,7 @@ class GroupedTopk(CustomOp):
self.topk_group = topk_group
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.num_fused_shared_experts = num_fused_shared_experts
def forward_native(
self,
......@@ -1333,6 +1338,32 @@ class GroupedTopk(CustomOp):
hidden_states, gating_output, e_score_correction_bias
)
def forward_hip(
    self,
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Grouped top-k entry point for the HIP code path.

    Delegates to ``rocm_aiter_grouped_topk`` when
    ``rocm_aiter_ops.is_fused_moe_enabled()`` reports true; otherwise the
    call is forwarded unchanged to ``forward_native``.

    Args:
        hidden_states: Passed through to the selected implementation.
        gating_output: Passed through to the selected implementation.
        e_score_correction_bias: Optional tensor passed through to the
            selected implementation; defaults to ``None``.

    Returns:
        The pair of tensors produced by the selected implementation.

    Raises:
        AssertionError: If AITER fused MoE is enabled, shared-expert
            fusion is not enabled, and ``num_fused_shared_experts`` is
            non-zero.
    """
    # Guard clause: without AITER fused MoE there is nothing
    # HIP-specific to do, so reuse the native implementation.
    if not rocm_aiter_ops.is_fused_moe_enabled():
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    # Fused shared experts may only be configured when the matching
    # AITER shared-experts fusion switch is on.
    shared_experts_fusion_on = (
        rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
    )
    if not shared_experts_fusion_on:
        assert self.num_fused_shared_experts == 0

    # Arguments stay positional, in the same order as before.
    return rocm_aiter_grouped_topk(
        hidden_states,
        gating_output,
        self.topk,
        self.renormalize,
        self.num_expert_group,
        self.topk_group,
        self.scoring_func,
        self.routed_scaling_factor,
        e_score_correction_bias,
        self.num_fused_shared_experts,
    )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
......
......@@ -4,7 +4,6 @@
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import Literal, cast, get_args, overload
import torch
......@@ -67,9 +66,6 @@ else:
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
......@@ -1583,28 +1579,15 @@ class FusedMoE(CustomOp):
elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None
assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
)
else:
grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
)
grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment