Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2c8ce57
Unverified
Commit
b2c8ce57
authored
Aug 07, 2025
by
Shu Wang
Committed by
GitHub
Aug 07, 2025
Browse files
Fix Flashinfer CUTLASS MOE Allgather (#21963)
Signed-off-by:
Shu Wang
<
shuw@nvidia.com
>
parent
a3b9c17b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
27 deletions
+71
-27
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+2
-1
vllm/forward_context.py
vllm/forward_context.py
+58
-0
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
...r/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+4
-20
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+7
-6
No files found.
vllm/distributed/device_communicators/cuda_communicator.py
View file @
b2c8ce57
...
...
@@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
input_size
=
input_
.
size
()
if
sizes
is
not
None
:
assert
len
(
sizes
)
==
world_size
assert
input_
.
shape
[
dim
]
==
sizes
[
self
.
rank_in_group
]
assert
input_
.
shape
[
dim
]
==
sizes
[
self
.
rank_in_group
],
(
f
"
{
input_
.
shape
[
dim
]
}
!=
{
sizes
[
self
.
rank_in_group
]
}
"
)
output_size
=
(
sum
(
sizes
),
)
+
input_size
[
1
:]
else
:
output_size
=
(
input_size
[
0
]
*
world_size
,
)
+
input_size
[
1
:]
...
...
vllm/forward_context.py
View file @
b2c8ce57
...
...
@@ -26,10 +26,26 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time
:
defaultdict
=
defaultdict
(
list
)
def
_compute_chunked_local_num_tokens
(
num_tokens_across_dp_cpu
:
list
[
int
],
max_num_tokens
:
int
,
chunk_idx
:
int
)
->
list
[
int
]:
dp_size
=
len
(
num_tokens_across_dp_cpu
)
local_size
=
[
-
1
]
*
dp_size
for
i
in
range
(
dp_size
):
dp_tokens
=
num_tokens_across_dp_cpu
[
i
]
local_size
[
i
]
=
min
(
max_num_tokens
,
dp_tokens
-
(
max_num_tokens
*
chunk_idx
))
if
local_size
[
i
]
<=
0
:
local_size
[
i
]
=
1
# ensure lockstep even if done
return
local_size
@
dataclass
class
DPMetadata
:
max_tokens_across_dp_cpu
:
torch
.
Tensor
cu_tokens_across_dp_cpu
:
torch
.
Tensor
local_sizes
:
Optional
[
list
[
int
]]
=
None
@
staticmethod
def
num_tokens_across_dp
(
num_tokens
:
int
,
dp_size
:
int
,
...
...
@@ -78,6 +94,48 @@ class DPMetadata:
cu_tokens_across_dp_cpu
=
torch
.
cumsum
(
num_tokens_across_dp
,
dim
=
0
)
return
DPMetadata
(
max_tokens_across_dp_cpu
,
cu_tokens_across_dp_cpu
)
@
contextmanager
def
chunked_sizes
(
self
,
max_chunk_size_per_rank
:
int
,
chunk_idx
:
int
):
"""
Context manager to compute and temporarily set the per-rank local token
sizes for a specific chunk during chunked forward execution.
This is necessary to ensure each DP (data parallel) rank processes its
designated portion of tokens in lockstep with others, even when the
token counts are uneven or some ranks have completed their input early.
For chunked execution, we break up the total tokens on each rank into
multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
`chunk_idx`, this context manager sets `self.local_sizes` to the number
of tokens to process in that chunk on each rank.
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
to determine the chunk-wise split.
`self.local_sizes` is only valid inside the context.
Args:
max_chunk_size_per_rank: The max number of tokens each rank is
allowed to process in this chunk.
chunk_idx: The index of the chunk to compute sizes for.
"""
cu_sizes
=
self
.
cu_tokens_across_dp_cpu
num_tokens_across_dp_cpu
=
[
(
cu_sizes
[
i
]
-
cu_sizes
[
i
-
1
]).
item
()
if
i
>
0
else
cu_sizes
[
0
].
item
()
for
i
in
range
(
len
(
cu_sizes
))
]
self
.
local_sizes
=
_compute_chunked_local_num_tokens
(
num_tokens_across_dp_cpu
,
max_chunk_size_per_rank
,
chunk_idx
)
try
:
yield
self
.
local_sizes
finally
:
self
.
local_sizes
=
None
def
get_chunk_sizes_across_dp_rank
(
self
)
->
Optional
[
list
[
int
]]:
return
self
.
local_sizes
@
dataclass
class
ForwardContext
:
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
View file @
b2c8ce57
...
...
@@ -4,7 +4,6 @@ from typing import Any, Optional
import
torch
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.distributed
import
get_dp_group
from
vllm.forward_context
import
get_forward_context
...
...
@@ -14,20 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
from
vllm.utils.flashinfer
import
nvfp4_block_scale_interleave
def
get_local_sizes
(
local_tokens
):
cu_sizes
=
get_forward_context
().
dp_metadata
.
cu_tokens_across_dp_cpu
sizes
=
[
cu_sizes
[
0
].
item
()]
for
i
in
range
(
1
,
len
(
cu_sizes
)):
sizes
.
append
((
cu_sizes
[
i
]
-
cu_sizes
[
i
-
1
]).
item
())
max_num_tokens
=
envs
.
VLLM_MOE_DP_CHUNK_SIZE
sizes_chunked
=
[
max_num_tokens
]
*
len
(
sizes
)
if
local_tokens
<
max_num_tokens
:
# When the number of local tokens is less than max_num_tokens, all other
# ranks will also have fewer than max_num_tokens. The remaining tokens
# are accounted for as residual.
sizes_chunked
=
[
x
%
max_num_tokens
for
x
in
sizes
]
return
sizes_chunked
def
get_local_sizes
():
return
get_forward_context
().
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
class
FlashInferCutlassMoEPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
...
...
@@ -90,7 +77,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
,
topk_ids
,
a1q
,
a1q_scale
=
\
get_dp_group
().
all_gatherv
([
topk_weights
,
topk_ids
,
a1q
,
a1q_scale
],
# noqa: E501
dim
=
0
,
sizes
=
get_local_sizes
(
local_tokens
))
sizes
=
get_local_sizes
())
a1_m
,
a1_n
=
a1q
.
shape
a1q_scale
=
nvfp4_block_scale_interleave
(
a1q_scale
)
...
...
@@ -107,8 +94,5 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
[
'use_dp'
,
'local_tokens'
])
if
use_dp
:
fused_expert_output
=
get_dp_group
().
reduce_scatterv
(
fused_expert_output
,
dim
=
0
,
sizes
=
get_local_sizes
(
local_tokens
),
)
fused_expert_output
,
dim
=
0
,
sizes
=
get_local_sizes
())
output
.
copy_
(
fused_expert_output
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
b2c8ce57
...
...
@@ -1570,15 +1570,16 @@ class FusedMoE(torch.nn.Module):
max_tokens_across_dp
=
ctx
.
dp_metadata
.
max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank
=
self
.
moe_config
.
max_num_tokens
num_tokens
=
full_hidden_states
.
size
(
0
)
for
chunk_
start_
in
range
(
0
,
max_tokens_across_dp
,
moe_dp_chunk_size_per_rank
):
for
chunk_
idx
,
chunk_start_
in
enumerate
(
range
(
0
,
max_tokens_across_dp
,
moe_dp_chunk_size_per_rank
)
)
:
chunk_start
=
chunk_start_
chunk_end
=
min
(
chunk_start
+
moe_dp_chunk_size_per_rank
,
max_tokens_across_dp
)
# clamp start and end
chunk_start
=
min
(
chunk_start
,
num_tokens
-
1
)
chunk_end
=
min
(
chunk_end
,
num_tokens
)
with
ctx
.
dp_metadata
.
chunked_sizes
(
moe_dp_chunk_size_per_rank
,
chunk_idx
):
process_chunk
(
chunk_start
,
chunk_end
,
skip_result_store
=
chunk_start_
>=
num_tokens
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment