Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b1d440e
Unverified
Commit
3b1d440e
authored
Dec 17, 2025
by
Xinyu Chen
Committed by
GitHub
Dec 17, 2025
Browse files
CustomOp: grouped topk (#29575)
Signed-off-by:
Xinyu Chen
<
xinyu1.chen@intel.com
>
parent
a9e15c21
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
75 additions
and
14 deletions
+75
-14
tests/kernels/moe/test_grouped_topk.py
tests/kernels/moe/test_grouped_topk.py
+6
-4
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+2
-2
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+52
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+15
-8
No files found.
tests/kernels/moe/test_grouped_topk.py
View file @
3b1d440e
...
...
@@ -9,8 +9,8 @@ import pytest
import
torch
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
GroupedTopk
,
fused_grouped_topk
,
grouped_topk
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -50,15 +50,17 @@ def test_grouped_topk(
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_FUSED_MOE_GROUPED_TOPK"
,
"0"
)
baseline_topk_weights
,
baseline_topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
gating_output
,
grouped_topk
=
GroupedTopk
(
topk
=
topk
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
)
baseline_topk_weights
,
baseline_topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
gating_output
,
e_score_correction_bias
=
e_score_correction_bias
,
)
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
3b1d440e
...
...
@@ -77,11 +77,11 @@ if HAS_TRITON:
BatchedTritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
GroupedTopk
,
TritonExperts
,
fused_experts
,
fused_topk
,
get_config_file_name
,
grouped_topk
,
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
...
...
@@ -91,7 +91,7 @@ if HAS_TRITON:
"fused_topk"
,
"fused_experts"
,
"get_config_file_name"
,
"
g
rouped
_t
opk"
,
"
G
rouped
T
opk"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp4"
,
"cutlass_moe_w4a8_fp8"
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
3b1d440e
...
...
@@ -16,6 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
)
...
...
@@ -1286,6 +1287,57 @@ def grouped_topk(
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
@
CustomOp
.
register
(
"grouped_topk"
)
class
GroupedTopk
(
CustomOp
):
"""GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
def
__init__
(
self
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
)
->
None
:
super
().
__init__
()
self
.
native_impl
=
grouped_topk
self
.
topk
=
topk
self
.
renormalize
=
renormalize
self
.
num_expert_group
=
num_expert_group
self
.
topk_group
=
topk_group
self
.
scoring_func
=
scoring_func
self
.
routed_scaling_factor
=
routed_scaling_factor
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
native_impl
(
hidden_states
,
gating_output
,
self
.
topk
,
self
.
renormalize
,
self
.
num_expert_group
,
self
.
topk_group
,
self
.
scoring_func
,
self
.
routed_scaling_factor
,
e_score_correction_bias
,
)
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
forward_native
(
hidden_states
,
gating_output
,
e_score_correction_bias
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
eplb_map_to_physical_and_record
(
topk_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
3b1d440e
...
...
@@ -67,7 +67,7 @@ else:
return
topk_ids
eplb_map_to_physical_and_record
=
_eplb_map_to_physical_and_record
from
vllm.model_executor.layers.fused_moe.fused_moe
import
g
rouped
_t
opk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
G
rouped
T
opk
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa: E501
rocm_aiter_grouped_topk
,
)
...
...
@@ -1594,19 +1594,26 @@ class FusedMoE(CustomOp):
grouped_topk_impl
=
partial
(
rocm_aiter_grouped_topk
,
num_fused_shared_experts
=
self
.
num_fused_shared_experts
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
num_expert_group
=
self
.
num_expert_group
,
topk_group
=
self
.
topk_group
,
scoring_func
=
self
.
scoring_func
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
)
else
:
grouped_topk_impl
=
grouped_topk
grouped_topk_impl
=
GroupedTopk
(
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
num_expert_group
=
self
.
num_expert_group
,
topk_group
=
self
.
topk_group
,
scoring_func
=
self
.
scoring_func
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
)
topk_weights
,
topk_ids
=
grouped_topk_impl
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
num_expert_group
=
self
.
num_expert_group
,
topk_group
=
self
.
topk_group
,
scoring_func
=
self
.
scoring_func
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
)
elif
self
.
e_score_correction_bias
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment