Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
08f425ba
Unverified
Commit
08f425ba
authored
Jan 02, 2026
by
Xinyu Chen
Committed by
GitHub
Jan 02, 2026
Browse files
CustomOp: test forward dispatch for grouped_topk (#31530)
Signed-off-by:
Xinyu Chen
<
xinyu1.chen@intel.com
>
parent
a01f2fae
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
1 deletion
+13
-1
tests/kernels/moe/test_grouped_topk.py
tests/kernels/moe/test_grouped_topk.py
+13
-1
No files found.
tests/kernels/moe/test_grouped_topk.py
View file @
08f425ba
...
@@ -8,6 +8,12 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
...
@@ -8,6 +8,12 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
import
pytest
import
pytest
import
torch
import
torch
from
vllm.config
import
(
CompilationConfig
,
VllmConfig
,
get_cached_compilation_config
,
set_current_vllm_config
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
GroupedTopk
,
GroupedTopk
,
fused_grouped_topk
,
fused_grouped_topk
,
...
@@ -41,6 +47,11 @@ def test_grouped_topk(
...
@@ -41,6 +47,11 @@ def test_grouped_topk(
routed_scaling_factor
:
float
,
routed_scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
):
):
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"all"
,
"+grouped_topk"
])
)
get_cached_compilation_config
.
cache_clear
()
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
hidden_states
=
torch
.
randn
((
n_token
,
n_hidden
),
dtype
=
dtype
,
device
=
"cuda"
)
hidden_states
=
torch
.
randn
((
n_token
,
n_hidden
),
dtype
=
dtype
,
device
=
"cuda"
)
gating_output
=
torch
.
randn
((
n_token
,
n_expert
),
dtype
=
dtype
,
device
=
"cuda"
)
gating_output
=
torch
.
randn
((
n_token
,
n_expert
),
dtype
=
dtype
,
device
=
"cuda"
)
...
@@ -48,7 +59,7 @@ def test_grouped_topk(
...
@@ -48,7 +59,7 @@ def test_grouped_topk(
(
n_expert
,),
dtype
=
torch
.
float32
,
device
=
"cuda"
(
n_expert
,),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
)
with
monkeypatch
.
context
()
as
m
:
with
set_current_vllm_config
(
vllm_config
),
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_FUSED_MOE_GROUPED_TOPK"
,
"0"
)
m
.
setenv
(
"VLLM_USE_FUSED_MOE_GROUPED_TOPK"
,
"0"
)
grouped_topk
=
GroupedTopk
(
grouped_topk
=
GroupedTopk
(
topk
=
topk
,
topk
=
topk
,
...
@@ -58,6 +69,7 @@ def test_grouped_topk(
...
@@ -58,6 +69,7 @@ def test_grouped_topk(
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
,
)
)
assert
grouped_topk
.
_forward_method
.
__name__
==
"forward_cuda"
baseline_topk_weights
,
baseline_topk_ids
=
grouped_topk
(
baseline_topk_weights
,
baseline_topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
gating_output
=
gating_output
,
gating_output
=
gating_output
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment