Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
21fab0a3
Unverified
Commit
21fab0a3
authored
Apr 12, 2026
by
Le Yang
Committed by
GitHub
Apr 12, 2026
Browse files
fix(moe): fix RoutedExpertsCapturer assertion failure with DP>1 and MK path (#37879)
parent
3244a2eb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
103 additions
and
5 deletions
+103
-5
tests/model_executor/test_routed_experts_capture.py
tests/model_executor/test_routed_experts_capture.py
+82
-0
vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
...odel_executor/layers/fused_moe/routed_experts_capturer.py
+21
-5
No files found.
tests/model_executor/test_routed_experts_capture.py
View file @
21fab0a3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
types
import
types
from
types
import
SimpleNamespace
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
torch
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
from
vllm.model_executor.layers.fused_moe.routed_experts_capturer
import
(
RoutedExpertsCapturer
,
)
from
vllm.model_executor.layers.fused_moe.router.base_router
import
BaseRouter
from
vllm.model_executor.layers.fused_moe.router.base_router
import
BaseRouter
pytestmark
=
pytest
.
mark
.
cpu_test
pytestmark
=
pytest
.
mark
.
cpu_test
_REC_MODULE
=
"vllm.model_executor.layers.fused_moe.routed_experts_capturer"
def _capturer_with_buffer(
    *,
    max_tokens: int = 8,
    num_layers: int = 4,
    num_experts_per_tok: int = 2,
    dp_rank: int = 0,
) -> RoutedExpertsCapturer:
    """Return a RoutedExpertsCapturer with a pre-filled capture buffer.

    The buffer is an int32 tensor of shape
    ``(max_tokens, num_layers, num_experts_per_tok)`` filled with ``-1`` so
    that tests can distinguish slots written by ``capture`` from untouched
    ones. ``dp_rank`` is assigned directly onto the instance.
    """
    buffer_shape = (max_tokens, num_layers, num_experts_per_tok)
    capturer = RoutedExpertsCapturer()
    capturer.dp_rank = dp_rank
    capturer._device_buffer = torch.full(buffer_shape, -1, dtype=torch.int32)
    return capturer
class
DummyRouter
(
BaseRouter
):
class
DummyRouter
(
BaseRouter
):
@
property
@
property
...
@@ -159,3 +183,61 @@ def test_gpu_model_runner_binding_stage(monkeypatch):
...
@@ -159,3 +183,61 @@ def test_gpu_model_runner_binding_stage(monkeypatch):
assert
callable
(
dummy_module
.
router
.
capture_fn
)
assert
callable
(
dummy_module
.
router
.
capture_fn
)
dummy_module
.
router
.
capture_fn
(
torch
.
tensor
([[
9
,
10
]]))
dummy_module
.
router
.
capture_fn
(
torch
.
tensor
([[
9
,
10
]]))
assert
len
(
capturer
.
calls
)
==
1
assert
len
(
capturer
.
calls
)
==
1
def test_routed_experts_capturer_single_dp_no_metadata():
    """dp_metadata is None: capture writes the full topk_ids rows."""
    capturer = _capturer_with_buffer(dp_rank=0)
    routed_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)
    forward_ctx = SimpleNamespace(dp_metadata=None)

    # Replace the forward-context lookup in the module under test.
    patch_target = f"{_REC_MODULE}.get_forward_context"
    with patch(patch_target, return_value=forward_ctx):
        capturer.capture(layer_id=0, topk_ids=routed_ids)

    layer0 = capturer._device_buffer[:, 0, :]
    # The three captured rows land at the start of layer 0 ...
    assert torch.equal(layer0[:3], routed_ids)
    # ... and the next row still holds the -1 fill value.
    assert layer0[3, 0].item() == -1
def test_routed_experts_capturer_dp_naive_concatenated_all_ranks():
    """n == sum(num_tokens_dp): slice this rank's segment from concatenated topk."""
    capturer = _capturer_with_buffer(dp_rank=1)
    tokens_per_rank = torch.tensor([2, 3], dtype=torch.int32)
    dp_meta = SimpleNamespace(num_tokens_across_dp_cpu=tokens_per_rank)
    forward_ctx = SimpleNamespace(dp_metadata=dp_meta)

    # Concatenated order: rank0 rows then rank1 rows.
    all_rank_ids = torch.tensor(
        [[0, 1], [2, 3], [10, 11], [12, 13], [14, 15]],
        dtype=torch.int32,
    )

    patch_target = f"{_REC_MODULE}.get_forward_context"
    with patch(patch_target, return_value=forward_ctx):
        capturer.capture(layer_id=0, topk_ids=all_rank_ids)

    # Rank 1 owns rows 2..4 of the concatenated tensor.
    rank1_rows = all_rank_ids[2:5]
    captured = capturer._device_buffer[:3, 0, :]
    assert torch.equal(captured, rank1_rows)
def test_routed_experts_capturer_dp_modular_local_tokens():
    """n == token_num_per_dp: topk is already local to this DP rank."""
    capturer = _capturer_with_buffer(dp_rank=1)
    tokens_per_rank = torch.tensor([2, 3], dtype=torch.int32)
    dp_meta = SimpleNamespace(num_tokens_across_dp_cpu=tokens_per_rank)
    forward_ctx = SimpleNamespace(dp_metadata=dp_meta)

    # Three rows: exactly the token count recorded for dp_rank=1.
    local_ids = torch.tensor(
        [[10, 11], [12, 13], [14, 15]],
        dtype=torch.int32,
    )

    patch_target = f"{_REC_MODULE}.get_forward_context"
    with patch(patch_target, return_value=forward_ctx):
        capturer.capture(layer_id=0, topk_ids=local_ids)

    captured = capturer._device_buffer[:3, 0, :]
    assert torch.equal(captured, local_ids)
def test_routed_experts_capturer_dp_unexpected_batch_raises():
    """Mismatch between topk batch dim and DP layout: fail fast."""
    capturer = _capturer_with_buffer(dp_rank=0)
    tokens_per_rank = torch.tensor([2, 3], dtype=torch.int32)
    dp_meta = SimpleNamespace(num_tokens_across_dp_cpu=tokens_per_rank)
    forward_ctx = SimpleNamespace(dp_metadata=dp_meta)

    # total=5, local=2: n=1 matches neither naive (5) nor modular (2).
    mismatched_ids = torch.tensor([[1, 2]], dtype=torch.int32)

    patch_target = f"{_REC_MODULE}.get_forward_context"
    expected_error = pytest.raises(
        AssertionError, match="unexpected topk_ids batch dim"
    )
    with patch(patch_target, return_value=forward_ctx), expected_error:
        capturer.capture(layer_id=0, topk_ids=mismatched_ids)

    # The failed capture must leave the buffer at its -1 fill value.
    first_slot = capturer._device_buffer[0, 0, 0].item()
    assert first_slot == -1
vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
View file @
21fab0a3
...
@@ -176,11 +176,27 @@ class RoutedExpertsCapturer:
...
@@ -176,11 +176,27 @@ class RoutedExpertsCapturer:
end_loc
=
topk_ids
.
shape
[
0
]
end_loc
=
topk_ids
.
shape
[
0
]
token_num_per_dp
=
topk_ids
.
shape
[
0
]
token_num_per_dp
=
topk_ids
.
shape
[
0
]
else
:
# multi dp
else
:
# multi dp
token_num_per_dp
=
ctx
.
dp_metadata
.
num_tokens_across_dp_cpu
[
self
.
dp_rank
]
num_tokens_dp
=
ctx
.
dp_metadata
.
num_tokens_across_dp_cpu
cumsum
=
torch
.
cumsum
(
ctx
.
dp_metadata
.
num_tokens_across_dp_cpu
,
dim
=
0
)
token_num_per_dp
=
int
(
num_tokens_dp
[
self
.
dp_rank
].
item
())
assert
cumsum
[
-
1
]
==
topk_ids
.
shape
[
0
]
total
=
int
(
num_tokens_dp
.
sum
().
item
())
end_loc
=
cumsum
[
self
.
dp_rank
]
n
=
topk_ids
.
shape
[
0
]
start_loc
=
end_loc
-
token_num_per_dp
if
n
==
total
:
# Naive dispatch: all DP ranks' tokens concatenated before routing.
cumsum
=
torch
.
cumsum
(
num_tokens_dp
,
dim
=
0
)
end_loc
=
int
(
cumsum
[
self
.
dp_rank
].
item
())
start_loc
=
end_loc
-
token_num_per_dp
elif
n
==
token_num_per_dp
:
# Modular-kernel path: DP combine happens inside quant_method.apply;
# select_experts only sees this rank's tokens.
start_loc
=
0
end_loc
=
token_num_per_dp
else
:
raise
AssertionError
(
"RoutedExpertsCapturer: unexpected topk_ids batch dim "
f
"
{
n
}
(expected
{
total
}
or
{
token_num_per_dp
}
"
f
"for dp_rank=
{
self
.
dp_rank
}
)"
)
if
layer_id
>=
self
.
_device_buffer
.
shape
[
1
]:
if
layer_id
>=
self
.
_device_buffer
.
shape
[
1
]:
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment