Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55d8073d
Unverified
Commit
55d8073d
authored
Mar 12, 2026
by
Yifan Qiao
Committed by
GitHub
Mar 13, 2026
Browse files
[Bugfix] ep_scatter kernel store-load race condition (#34991)
Signed-off-by:
Yifan Qiao
<
yifanqiao@berkeley.edu
>
parent
cd32d6f5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
3 deletions
+8
-3
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
+8
-3
No files found.
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
View file @
55d8073d
...
...
@@ -76,9 +76,13 @@ def _fwd_kernel_ep_scatter_1(
)
tokens_per_expert
=
round_up_128
(
tokens_per_expert
)
cumsum
=
tl
.
cumsum
(
tokens_per_expert
)
-
tokens_per_expert
tl
.
store
(
expert_start_loc
+
offset_cumsum
,
cumsum
,
mask
=
offset_cumsum
<
num_experts
)
cur_expert_start
=
tl
.
load
(
expert_start_loc
+
cur_expert
)
# Extract this block's offset from the register vector (warp shuffle,
# no global memory round-trip) then write it once to expert_start_loc.
cur_expert_start
=
tl
.
sum
(
tl
.
where
(
offset_cumsum
==
cur_expert
,
cumsum
,
tl
.
zeros_like
(
cumsum
))
)
tl
.
store
(
expert_start_loc
+
cur_expert
,
cur_expert_start
)
cur_expert_token_num
=
tl
.
load
(
num_recv_tokens_per_expert
+
cur_expert
)
m_indices_start_ptr
=
m_indices
+
cur_expert_start
...
...
@@ -87,7 +91,7 @@ def _fwd_kernel_ep_scatter_1(
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for
start_m
in
tl
.
range
(
0
,
cur_expert_token_num
,
BLOCK_E
,
num_stages
=
4
):
for
start_m
in
tl
.
range
(
0
,
cur_expert_token_num
,
BLOCK_E
):
offs
=
start_m
+
off_expert
mask
=
offs
<
cur_expert_token_num
tl
.
store
(
...
...
@@ -186,6 +190,7 @@ def ep_scatter(
grid
=
num_experts
assert
m_indices
.
shape
[
0
]
%
BLOCK_E
==
0
assert
expert_start_loc
.
shape
[
0
]
==
num_experts
_fwd_kernel_ep_scatter_1
[(
grid
,)](
num_recv_tokens_per_expert
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment