Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f5e0d56
Unverified
Commit
0f5e0d56
authored
May 31, 2025
by
vllmellm
Committed by
GitHub
May 31, 2025
Browse files
[FEAT][ROCm] Add AITER grouped topk for DeepSeekV2 (#18825)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
c55d8046
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
157 additions
and
16 deletions
+157
-16
tests/kernels/moe/test_rocm_aiter_topk.py
tests/kernels/moe/test_rocm_aiter_topk.py
+93
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-1
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+63
-15
No files found.
tests/kernels/moe/test_rocm_aiter_topk.py
View file @
0f5e0d56
...
...
@@ -35,6 +35,15 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
assert
callable
(
torch
.
ops
.
vllm
.
rocm_aiter_biased_grouped_topk
)
def test_rocm_aiter_grouped_topk_custom_op_registration():
    """Verify `rocm_aiter_grouped_topk` is exposed under `torch.ops.vllm`."""
    vllm_ops = torch.ops.vllm

    # The op name must resolve inside the vllm op namespace ...
    assert hasattr(vllm_ops, 'rocm_aiter_grouped_topk')

    # ... and the resolved object must be invocable.
    assert callable(vllm_ops.rocm_aiter_grouped_topk)
def
test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility
():
"""Test that the op can be used with torch.compile."""
# Create test tensors
...
...
@@ -120,3 +129,87 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
rtol
=
1e-2
,
atol
=
1e-2
)
assert
torch
.
allclose
(
topk_ids_original
,
topk_ids_compiled
)
def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
    """Check that eager and `torch.compile` runs of the op agree.

    The op is run once directly and once through an inductor-compiled
    wrapper on the same gating logits; the selected expert ids and their
    weights are then compared after sorting each row by expert id.
    """
    # Routing configuration used for the comparison.
    token, expert = 64, 256
    num_expert_group, topk_group = 8, 4
    topk = 8
    renormalize = True
    scoring_func = "softmax"
    scale_factor = 1.0

    gating_output = torch.randn((token, expert),
                                dtype=torch.bfloat16,
                                device="cuda")
    device = gating_output.device

    def alloc_outputs():
        """Return fresh, uninitialized (topk_weights, topk_ids) buffers."""
        weights = torch.empty((token, topk),
                              dtype=torch.float32,
                              device=device)
        ids = torch.empty((token, topk), dtype=torch.int32, device=device)
        return weights, ids

    def sort_by_expert_id(weights, ids):
        """Sort each row by expert id and reorder the weights to match."""
        sorted_ids, order = torch.sort(ids)
        return torch.gather(weights, 1, order), sorted_ids

    # Thin wrapper around the op; this is the function compiled below.
    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
        return torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output, topk_weights, topk_ids, num_expert_group,
            topk_group, renormalize, scoring_func, scale_factor)

    # Verify the op's fake implementation
    topk_weights, topk_ids = alloc_outputs()
    torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
                          (gating_output, topk_weights, topk_ids),
                          kwargs={
                              "num_expert_group": num_expert_group,
                              "topk_group": topk_group,
                              "need_renorm": renormalize,
                              "scoring_func": scoring_func,
                              "routed_scaling_factor": scale_factor
                          },
                          test_utils="test_faketensor")

    # Compile the wrapper with static shapes and full-graph capture.
    compiled_fn = torch.compile(grouped_topk_fn,
                                fullgraph=True,
                                backend="inductor",
                                mode="reduce-overhead",
                                dynamic=False)

    # Separate output buffers so the two runs cannot overwrite each other.
    topk_weights_original, topk_ids_original = alloc_outputs()
    topk_weights_compiled, topk_ids_compiled = alloc_outputs()

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
                    scoring_func)
    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
                scoring_func)

    # Sort the results for comparison since the order might not be
    # deterministic
    topk_weights_original, topk_ids_original = sort_by_expert_id(
        topk_weights_original, topk_ids_original)
    topk_weights_compiled, topk_ids_compiled = sort_by_expert_id(
        topk_weights_compiled, topk_ids_compiled)

    # Verify results match
    assert torch.allclose(topk_weights_original,
                          topk_weights_compiled,
                          rtol=1e-2,
                          atol=1e-2)
    assert torch.allclose(topk_ids_original, topk_ids_compiled)
vllm/model_executor/layers/fused_moe/layer.py
View file @
0f5e0d56
...
...
@@ -45,7 +45,7 @@ else:
FusedMoEPrepareAndFinalize
=
None
# type: ignore
# Pick the grouped top-k routing implementation once, at import time.
# Both branches bind the same local name `grouped_topk`, so the rest of
# the module does not need to know which backend was selected.
if is_rocm_aiter_moe_enabled():
    # The AITER wrapper is named `rocm_aiter_grouped_topk`; the earlier
    # `rocm_aiter_biased_group_topk` name must not be imported alongside it.
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
        rocm_aiter_grouped_topk as grouped_topk)
else:
    from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if
current_platform
.
is_tpu
():
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
0f5e0d56
...
...
@@ -140,6 +140,36 @@ def rocm_aiter_biased_grouped_topk_fake(
pass
def rocm_aiter_grouped_topk_impl(
        gating_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0  # mul to topk_weights
) -> None:
    """Forward every argument, unchanged and in order, to AITER's
    `grouped_topk`.

    Nothing is returned. This function is registered as a custom op with
    `topk_weights` and `topk_ids` declared as mutated arguments, so the
    caller is expected to pass pre-allocated output tensors for those two.
    """
    # Imported lazily so that `aiter` is only required when the op runs.
    from aiter import grouped_topk as aiter_grouped_topk

    aiter_grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        scoring_func,
        routed_scaling_factor,
    )
def rocm_aiter_grouped_topk_fake(
        gating_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0  # mul to topk_weights
) -> None:
    """Fake implementation of `rocm_aiter_grouped_topk`: intentionally a no-op.

    It takes the same parameters as `rocm_aiter_grouped_topk_impl`, touches
    none of them and returns nothing.
    """
    return None
def
rocm_aiter_fused_moe_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
...
@@ -218,36 +248,54 @@ if current_platform.is_rocm():
dispatch_key
=
current_platform
.
dispatch_key
,
)
# Register the non-biased AITER grouped top-k as a custom op named
# "rocm_aiter_grouped_topk". The two output tensors are declared as
# mutated because the implementation returns None.
# NOTE(review): the hunk header suggests this call sits under
# `if current_platform.is_rocm():` -- confirm the indentation in the file.
direct_register_custom_op(
    op_name="rocm_aiter_grouped_topk",
    op_func=rocm_aiter_grouped_topk_impl,
    fake_impl=rocm_aiter_grouped_topk_fake,
    mutates_args=["topk_weights", "topk_ids"],
    dispatch_key=current_platform.dispatch_key,
)
def rocm_aiter_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run grouped top-k expert selection through the ROCm AITER custom ops.

    Args:
        hidden_states: Only `shape[0]` (used as the token count) and
            `device` are read from this tensor.
        gating_output: Router output forwarded to the selected op.
        topk: Number of entries per token in the returned tensors.
        renormalize: Forwarded to the op as its renormalization flag.
        num_expert_group: Forwarded to the op unchanged.
        topk_group: Forwarded to the op unchanged.
        scoring_func: "softmax" or "sigmoid". Only forwarded (and only
            validated) on the non-biased path.
        e_score_correction_bias: When given, the biased op is used;
            otherwise the plain grouped top-k op is used.

    Returns:
        `(topk_weights, topk_ids)`: a float32 and an int32 tensor, both of
        shape `(hidden_states.shape[0], topk)` on `hidden_states.device`.
        Both are allocated here and passed to the op as output buffers.

    Raises:
        AssertionError: if no bias is given and `scoring_func` is neither
            "softmax" nor "sigmoid".
    """
    token = hidden_states.shape[0]
    device = hidden_states.device

    # Output buffers; the ops are registered as mutating these in place.
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk),
                               dtype=torch.float32,
                               device=device)

    if e_score_correction_bias is not None:
        # NOTE(review): `scoring_func` is not passed on this path. The
        # previous version of this function asserted it was "sigmoid" --
        # confirm callers never combine a bias with "softmax".
        torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            e_score_correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
        )
    else:
        assert scoring_func in ("softmax", "sigmoid"), (
            "rocm_aiter_grouped_topk only supports 'softmax' or 'sigmoid' "
            "scoring_func.")
        torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scoring_func,
        )

    return topk_weights, topk_ids
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment