"vllm/vscode:/vscode.git/clone" did not exist on "afb050b29d0cac27c32c19c8206a9ac2a4662de2"
Unverified commit 87f1b8ca, authored by Xinyu Chen and committed by GitHub
Browse files

CustomOp: Unify aiter impl into GroupedTopk (#31221)


Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
parent 887e900b
...@@ -35,6 +35,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -35,6 +35,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
...@@ -1295,6 +1298,7 @@ class GroupedTopk(CustomOp): ...@@ -1295,6 +1298,7 @@ class GroupedTopk(CustomOp):
topk_group: int = 0, topk_group: int = 0,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
num_fused_shared_experts: int = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
self.native_impl = grouped_topk self.native_impl = grouped_topk
...@@ -1304,6 +1308,7 @@ class GroupedTopk(CustomOp): ...@@ -1304,6 +1308,7 @@ class GroupedTopk(CustomOp):
self.topk_group = topk_group self.topk_group = topk_group
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor self.routed_scaling_factor = routed_scaling_factor
self.num_fused_shared_experts = num_fused_shared_experts
def forward_native( def forward_native(
self, self,
...@@ -1333,6 +1338,32 @@ class GroupedTopk(CustomOp): ...@@ -1333,6 +1338,32 @@ class GroupedTopk(CustomOp):
hidden_states, gating_output, e_score_correction_bias hidden_states, gating_output, e_score_correction_bias
) )
def forward_hip(
    self,
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run grouped top-k on the HIP path, using the aiter kernel when enabled.

    If ``rocm_aiter_ops.is_fused_moe_enabled()`` is false, the call is
    forwarded unchanged to ``forward_native``. Otherwise the routing
    parameters stored on this instance are passed to
    ``rocm_aiter_grouped_topk``.

    Args:
        hidden_states: Forwarded as-is to the selected implementation.
        gating_output: Forwarded as-is to the selected implementation.
        e_score_correction_bias: Optional tensor forwarded as-is to the
            selected implementation.

    Returns:
        The pair of tensors produced by the selected implementation.

    Raises:
        AssertionError: If the aiter path is taken, shared-experts fusion
            is not enabled, and ``num_fused_shared_experts`` is non-zero.
    """
    # Guard clause: without aiter fused MoE, defer to the native path.
    if not rocm_aiter_ops.is_fused_moe_enabled():
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    fusion_enabled = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
    if not fusion_enabled:
        # Fused shared experts may only be configured when the fusion
        # feature itself is switched on.
        assert self.num_fused_shared_experts == 0

    # Arguments are passed positionally, in the order the aiter wrapper
    # is called with in the original code.
    return rocm_aiter_grouped_topk(
        hidden_states,
        gating_output,
        self.topk,
        self.renormalize,
        self.num_expert_group,
        self.topk_group,
        self.scoring_func,
        self.routed_scaling_factor,
        e_score_correction_bias,
        self.num_fused_shared_experts,
    )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record( def eplb_map_to_physical_and_record(
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from functools import partial
from typing import Literal, cast, get_args, overload from typing import Literal, cast, get_args, overload
import torch import torch
...@@ -67,9 +66,6 @@ else: ...@@ -67,9 +66,6 @@ else:
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
if current_platform.is_tpu(): if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas from .moe_pallas import fused_moe as fused_moe_pallas
...@@ -1583,20 +1579,6 @@ class FusedMoE(CustomOp): ...@@ -1583,20 +1579,6 @@ class FusedMoE(CustomOp):
elif self.use_grouped_topk and valid_grouping(): elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None assert self.topk_group is not None
assert self.num_expert_group is not None assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
)
else:
grouped_topk_impl = GroupedTopk( grouped_topk_impl = GroupedTopk(
topk=self.top_k, topk=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
...@@ -1604,6 +1586,7 @@ class FusedMoE(CustomOp): ...@@ -1604,6 +1586,7 @@ class FusedMoE(CustomOp):
topk_group=self.topk_group, topk_group=self.topk_group,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
) )
topk_weights, topk_ids = grouped_topk_impl( topk_weights, topk_ids = grouped_topk_impl(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment