Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d532f9d
Unverified
Commit
1d532f9d
authored
Feb 27, 2026
by
Lucas Wilkinson
Committed by
GitHub
Feb 27, 2026
Browse files
[DP] Only use DP padding when cudagraphs are actually used (#34102)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
234a65b7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
139 additions
and
127 deletions
+139
-127
tests/v1/cudagraph/test_cudagraph_dispatch.py
tests/v1/cudagraph/test_cudagraph_dispatch.py
+16
-2
vllm/config/compilation.py
vllm/config/compilation.py
+6
-2
vllm/forward_context.py
vllm/forward_context.py
+1
-2
vllm/v1/cudagraph_dispatcher.py
vllm/v1/cudagraph_dispatcher.py
+40
-22
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+48
-49
vllm/v1/worker/dp_utils.py
vllm/v1/worker/dp_utils.py
+18
-30
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+10
-20
No files found.
tests/v1/cudagraph/test_cudagraph_dispatch.py
View file @
1d532f9d
...
@@ -176,10 +176,14 @@ class TestCudagraphDispatcher:
...
@@ -176,10 +176,14 @@ class TestCudagraphDispatcher:
assert
rt_mode
==
CUDAGraphMode
.
NONE
assert
rt_mode
==
CUDAGraphMode
.
NONE
assert
key
==
BatchDescriptor
(
num_tokens
=
15
)
assert
key
==
BatchDescriptor
(
num_tokens
=
15
)
# 4. disable_full should have a fall back mode (e.g., cascade attention)
# 4. invalid_modes={FULL} should have a fall back mode
# (e.g., cascade attention)
desc_full_exact
=
BatchDescriptor
(
num_tokens
=
8
,
uniform
=
False
)
desc_full_exact
=
BatchDescriptor
(
num_tokens
=
8
,
uniform
=
False
)
rt_mode
,
key
=
dispatcher
.
dispatch
(
rt_mode
,
key
=
dispatcher
.
dispatch
(
num_tokens
=
8
,
uniform_decode
=
False
,
has_lora
=
False
,
disable_full
=
True
num_tokens
=
8
,
uniform_decode
=
False
,
has_lora
=
False
,
invalid_modes
=
{
CUDAGraphMode
.
FULL
},
)
)
if
"PIECEWISE"
in
cudagraph_mode_str
:
# string contains check
if
"PIECEWISE"
in
cudagraph_mode_str
:
# string contains check
...
@@ -188,6 +192,16 @@ class TestCudagraphDispatcher:
...
@@ -188,6 +192,16 @@ class TestCudagraphDispatcher:
else
:
else
:
assert
rt_mode
==
CUDAGraphMode
.
NONE
assert
rt_mode
==
CUDAGraphMode
.
NONE
# 5. valid_modes={NONE} always returns NONE even when keys exist
rt_mode
,
key
=
dispatcher
.
dispatch
(
num_tokens
=
8
,
uniform_decode
=
False
,
has_lora
=
False
,
valid_modes
=
{
CUDAGraphMode
.
NONE
},
)
assert
rt_mode
==
CUDAGraphMode
.
NONE
assert
key
==
BatchDescriptor
(
num_tokens
=
8
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"cudagraph_mode_str,compilation_mode,expected_modes"
,
"cudagraph_mode_str,compilation_mode,expected_modes"
,
[
[
...
...
vllm/config/compilation.py
View file @
1d532f9d
...
@@ -87,8 +87,12 @@ class CUDAGraphMode(enum.Enum):
...
@@ -87,8 +87,12 @@ class CUDAGraphMode(enum.Enum):
def
separate_routine
(
self
)
->
bool
:
def
separate_routine
(
self
)
->
bool
:
return
isinstance
(
self
.
value
,
tuple
)
return
isinstance
(
self
.
value
,
tuple
)
def
valid_runtime_modes
(
self
)
->
bool
:
@
classmethod
return
self
in
[
CUDAGraphMode
.
NONE
,
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
]
def
valid_runtime_modes
(
cls
)
->
frozenset
[
"CUDAGraphMode"
]:
return
frozenset
({
cls
.
NONE
,
cls
.
PIECEWISE
,
cls
.
FULL
})
def
is_valid_runtime_mode
(
self
)
->
bool
:
return
self
in
CUDAGraphMode
.
valid_runtime_modes
()
def
__str__
(
self
)
->
str
:
def
__str__
(
self
)
->
str
:
return
self
.
name
return
self
.
name
...
...
vllm/forward_context.py
View file @
1d532f9d
...
@@ -241,7 +241,7 @@ class ForwardContext:
...
@@ -241,7 +241,7 @@ class ForwardContext:
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_mode
s
(),
(
assert
self
.
cudagraph_runtime_mode
.
is_
valid_runtime_mode
(),
(
f
"Invalid cudagraph runtime mode:
{
self
.
cudagraph_runtime_mode
}
"
f
"Invalid cudagraph runtime mode:
{
self
.
cudagraph_runtime_mode
}
"
)
)
...
@@ -347,7 +347,6 @@ def set_forward_context(
...
@@ -347,7 +347,6 @@ def set_forward_context(
num_tokens_unpadded
=
num_tokens
,
num_tokens_unpadded
=
num_tokens
,
parallel_config
=
vllm_config
.
parallel_config
,
parallel_config
=
vllm_config
.
parallel_config
,
allow_microbatching
=
False
,
allow_microbatching
=
False
,
allow_dp_padding
=
False
,
)
)
assert
num_tokens_across_dp
is
not
None
assert
num_tokens_across_dp
is
not
None
dp_metadata
=
DPMetadata
.
make
(
dp_metadata
=
DPMetadata
.
make
(
...
...
vllm/v1/cudagraph_dispatcher.py
View file @
1d532f9d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Set
as
AbstractSet
from
dataclasses
import
replace
from
dataclasses
import
replace
from
itertools
import
product
from
itertools
import
product
...
@@ -232,8 +233,9 @@ class CudagraphDispatcher:
...
@@ -232,8 +233,9 @@ class CudagraphDispatcher:
num_tokens
:
int
,
num_tokens
:
int
,
uniform_decode
:
bool
=
False
,
uniform_decode
:
bool
=
False
,
has_lora
:
bool
=
False
,
has_lora
:
bool
=
False
,
disable_full
:
bool
=
False
,
num_active_loras
:
int
=
0
,
num_active_loras
:
int
=
0
,
valid_modes
:
AbstractSet
[
CUDAGraphMode
]
|
None
=
None
,
invalid_modes
:
AbstractSet
[
CUDAGraphMode
]
|
None
=
None
,
)
->
tuple
[
CUDAGraphMode
,
BatchDescriptor
]:
)
->
tuple
[
CUDAGraphMode
,
BatchDescriptor
]:
"""
"""
Given conditions(e.g.,batch descriptor and if using piecewise only),
Given conditions(e.g.,batch descriptor and if using piecewise only),
...
@@ -246,15 +248,29 @@ class CudagraphDispatcher:
...
@@ -246,15 +248,29 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
has_lora: Whether LoRA is active.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
num_active_loras: Number of distinct active LoRA adapters.
num_active_loras: Number of distinct active LoRA adapters.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
invalid_modes: Set of cudagraph modes to exclude. Subtracted from
valid_modes to compute allowed modes. (e.g., {FULL} for
features like cascade attention not supported by full
cudagraphs). None means no modes are excluded.
"""
"""
allowed_modes
=
valid_modes
or
CUDAGraphMode
.
valid_runtime_modes
()
if
invalid_modes
:
allowed_modes
-=
invalid_modes
assert
len
(
allowed_modes
)
>=
1
,
(
f
"No allowed cudagraph modes: valid_modes=
{
valid_modes
}
, "
f
"invalid_modes=
{
invalid_modes
}
"
)
if
(
if
(
not
self
.
keys_initialized
not
self
.
keys_initialized
or
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
or
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
or
num_tokens
>
self
.
compilation_config
.
max_cudagraph_capture_size
or
num_tokens
>
self
.
compilation_config
.
max_cudagraph_capture_size
or
allowed_modes
<=
{
CUDAGraphMode
.
NONE
}
):
):
return
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens
)
return
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens
)
...
@@ -281,24 +297,26 @@ class CudagraphDispatcher:
...
@@ -281,24 +297,26 @@ class CudagraphDispatcher:
num_tokens
,
uniform_decode
,
has_lora
,
effective_num_active_loras
num_tokens
,
uniform_decode
,
has_lora
,
effective_num_active_loras
)
)
# check if key exists for full cudagraph
if
CUDAGraphMode
.
FULL
in
allowed_modes
:
# For pure FULL mode, keys are registered with uniform=False.
# check if key exists for full cudagraph
batch_desc_to_check
=
batch_desc
# For pure FULL mode, keys are registered with uniform=False.
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
batch_desc_to_check
=
batch_desc
batch_desc_to_check
=
replace
(
batch_desc
,
uniform
=
False
)
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
if
(
batch_desc_to_check
=
replace
(
batch_desc
,
uniform
=
False
)
not
disable_full
if
batch_desc_to_check
in
self
.
cudagraph_keys
[
CUDAGraphMode
.
FULL
]:
and
batch_desc_to_check
in
self
.
cudagraph_keys
[
CUDAGraphMode
.
FULL
]
return
CUDAGraphMode
.
FULL
,
batch_desc_to_check
):
return
CUDAGraphMode
.
FULL
,
batch_desc_to_check
if
CUDAGraphMode
.
PIECEWISE
in
allowed_modes
:
# also check if the relaxed key exists for more "general"
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
# piecewise cudagraph
batch_desc_to_check
=
replace
(
batch_desc
,
num_reqs
=
None
,
uniform
=
False
)
batch_desc_to_check
=
replace
(
batch_desc
,
num_reqs
=
None
,
uniform
=
False
)
if
batch_desc_to_check
in
self
.
cudagraph_keys
[
CUDAGraphMode
.
PIECEWISE
]:
if
batch_desc_to_check
in
self
.
cudagraph_keys
[
CUDAGraphMode
.
PIECEWISE
]:
return
CUDAGraphMode
.
PIECEWISE
,
batch_desc_to_check
return
CUDAGraphMode
.
PIECEWISE
,
batch_desc_to_check
assert
CUDAGraphMode
.
NONE
in
allowed_modes
,
(
# finally, just return no cudagraphs and a trivial batch descriptor
f
"No matching cudagraph found and NONE is not in "
f
"allowed_modes=
{
allowed_modes
}
"
)
return
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens
)
return
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens
)
def
get_capture_descs
(
self
)
->
list
[
tuple
[
CUDAGraphMode
,
list
[
BatchDescriptor
]]]:
def
get_capture_descs
(
self
)
->
list
[
tuple
[
CUDAGraphMode
,
list
[
BatchDescriptor
]]]:
...
...
vllm/v1/spec_decode/eagle.py
View file @
1d532f9d
...
@@ -448,17 +448,10 @@ class SpecDecodeBaseProposer:
...
@@ -448,17 +448,10 @@ class SpecDecodeBaseProposer:
assert
draft_indexer_metadata
is
not
None
assert
draft_indexer_metadata
is
not
None
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
num_tokens_unpadded
=
num_tokens
,
num_tokens
_padd
ed
=
num_tokens
self
.
_determine_batch_execution_and
_padd
ing
(
num_tokens
)
)
)
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_dp_padded
)
num_input_tokens
=
batch_desc
.
num_tokens
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
...
@@ -549,16 +542,9 @@ class SpecDecodeBaseProposer:
...
@@ -549,16 +542,9 @@ class SpecDecodeBaseProposer:
# Generate the remaining draft tokens.
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
draft_token_ids_list
=
[
draft_token_ids
]
batch_size_dp_padded
,
batch_size_across_dp
=
self
.
_pad_batch_across_dp
(
cudagraph_runtime_mode
,
input_batch_size
,
batch_size_across_dp
=
(
num_tokens_unpadded
=
batch_size
,
num_tokens_padded
=
batch_size
self
.
_determine_batch_execution_and_padding
(
batch_size
)
)
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
batch_size_dp_padded
)
)
input_batch_size
=
batch_desc
.
num_tokens
if
batch_size_across_dp
is
not
None
:
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
max_query_len
=
1
...
@@ -1568,19 +1554,11 @@ class SpecDecodeBaseProposer:
...
@@ -1568,19 +1554,11 @@ class SpecDecodeBaseProposer:
self
.
num_speculative_tokens
if
not
is_graph_capturing
else
1
self
.
num_speculative_tokens
if
not
is_graph_capturing
else
1
):
):
if
fwd_idx
<=
1
:
if
fwd_idx
<=
1
:
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
self
.
_determine_batch_execution_and_padding
(
)
num_tokens
,
use_cudagraphs
=
use_cudagraphs
if
use_cudagraphs
:
cudagraph_runtime_mode
,
batch_desc
=
(
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_dp_padded
)
)
)
num_input_tokens
=
batch_desc
.
num_tokens
)
else
:
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
num_input_tokens
=
num_tokens_dp_padded
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
# Make sure to use EAGLE's own buffer during cudagraph capture.
# Make sure to use EAGLE's own buffer during cudagraph capture.
if
(
if
(
...
@@ -1680,28 +1658,49 @@ class SpecDecodeBaseProposer:
...
@@ -1680,28 +1658,49 @@ class SpecDecodeBaseProposer:
==
1
==
1
),
"All drafting layers should belong to the same kv cache group"
),
"All drafting layers should belong to the same kv cache group"
def
_
pad_batch_across_dp
(
def
_
determine_batch_execution_and_padding
(
self
,
self
,
num_tokens_unpadded
:
int
,
num_tokens
:
int
,
num_tokens_padded
:
int
,
use_cudagraphs
:
bool
=
True
,
)
->
tuple
[
int
,
torch
.
Tensor
]:
)
->
tuple
[
CUDAGraphMode
,
int
,
torch
.
Tensor
|
None
]:
# TODO(Flechman): support DBO ubatching
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
should_ubatch
,
num_toks_across_dp
,
_
=
coordinate_batch_across_dp
(
num_tokens
,
num_tokens_unpadded
=
num_tokens_unpadded
,
valid_modes
=
({
CUDAGraphMode
.
NONE
}
if
not
use_cudagraphs
else
None
),
parallel_config
=
self
.
vllm_config
.
parallel_config
,
allow_microbatching
=
False
,
allow_dp_padding
=
self
.
cudagraph_dispatcher
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
,
num_tokens_padded
=
num_tokens_padded
,
uniform_decode
=
None
,
num_scheduled_tokens_per_request
=
None
,
)
)
assert
not
should_ubatch
,
"DBO ubatching not implemented for EAGLE"
num_tokens_padded
=
batch_desc
.
num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch
,
num_tokens_across_dp
=
False
,
None
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
should_ubatch
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
coordinate_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
parallel_config
=
self
.
vllm_config
.
parallel_config
,
allow_microbatching
=
False
,
num_tokens_padded
=
num_tokens_padded
,
cudagraph_mode
=
cudagraph_mode
.
value
,
)
)
assert
not
should_ubatch
,
"DBO ubatching not implemented for EAGLE"
# Extract DP-synced values
if
num_tokens_across_dp
is
not
None
:
dp_rank
=
self
.
dp_rank
num_tokens_padded
=
int
(
num_tokens_across_dp
[
dp_rank
].
item
())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_padded
,
valid_modes
=
{
CUDAGraphMode
(
synced_cudagraph_mode
)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert
batch_desc
.
num_tokens
==
num_tokens_padded
num_tokens_across_dp
[
dp_rank
]
=
num_tokens_padded
num_tokens_dp_padded
=
num_tokens_padded
return
cudagraph_mode
,
num_tokens_padded
,
num_tokens_across_dp
if
num_toks_across_dp
is
not
None
:
num_tokens_dp_padded
=
int
(
num_toks_across_dp
[
self
.
dp_rank
].
item
())
return
num_tokens_dp_padded
,
num_toks_across_dp
class
EagleProposer
(
SpecDecodeBaseProposer
):
class
EagleProposer
(
SpecDecodeBaseProposer
):
...
...
vllm/v1/worker/dp_utils.py
View file @
1d532f9d
...
@@ -37,7 +37,6 @@ def _get_device_and_group(parallel_config: ParallelConfig):
...
@@ -37,7 +37,6 @@ def _get_device_and_group(parallel_config: ParallelConfig):
def
_run_ar
(
def
_run_ar
(
should_ubatch
:
bool
,
should_ubatch
:
bool
,
should_dp_pad
:
bool
,
orig_num_tokens_per_ubatch
:
int
,
orig_num_tokens_per_ubatch
:
int
,
padded_num_tokens_per_ubatch
:
int
,
padded_num_tokens_per_ubatch
:
int
,
cudagraph_mode
:
int
,
cudagraph_mode
:
int
,
...
@@ -46,12 +45,11 @@ def _run_ar(
...
@@ -46,12 +45,11 @@ def _run_ar(
dp_size
=
parallel_config
.
data_parallel_size
dp_size
=
parallel_config
.
data_parallel_size
dp_rank
=
parallel_config
.
data_parallel_rank
dp_rank
=
parallel_config
.
data_parallel_rank
device
,
group
=
_get_device_and_group
(
parallel_config
)
device
,
group
=
_get_device_and_group
(
parallel_config
)
tensor
=
torch
.
zeros
(
5
,
dp_size
,
device
=
device
,
dtype
=
torch
.
int32
)
tensor
=
torch
.
zeros
(
4
,
dp_size
,
device
=
device
,
dtype
=
torch
.
int32
)
tensor
[
0
][
dp_rank
]
=
orig_num_tokens_per_ubatch
tensor
[
0
][
dp_rank
]
=
orig_num_tokens_per_ubatch
tensor
[
1
][
dp_rank
]
=
padded_num_tokens_per_ubatch
tensor
[
1
][
dp_rank
]
=
padded_num_tokens_per_ubatch
tensor
[
2
][
dp_rank
]
=
1
if
should_ubatch
else
0
tensor
[
2
][
dp_rank
]
=
1
if
should_ubatch
else
0
tensor
[
3
][
dp_rank
]
=
1
if
should_dp_pad
else
0
tensor
[
3
][
dp_rank
]
=
cudagraph_mode
tensor
[
4
][
dp_rank
]
=
cudagraph_mode
dist
.
all_reduce
(
tensor
,
group
=
group
)
dist
.
all_reduce
(
tensor
,
group
=
group
)
return
tensor
return
tensor
...
@@ -97,14 +95,13 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
...
@@ -97,14 +95,13 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
If any rank has NONE (0), all ranks use NONE.
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
This ensures all ranks send consistent values (all padded or all unpadded).
"""
"""
return
int
(
tensor
[
4
,
:].
min
().
item
())
return
int
(
tensor
[
3
,
:].
min
().
item
())
def
_synchronize_dp_ranks
(
def
_synchronize_dp_ranks
(
num_tokens_unpadded
:
int
,
num_tokens_unpadded
:
int
,
num_tokens_padded
:
int
,
num_tokens_padded
:
int
,
should_attempt_ubatching
:
bool
,
should_attempt_ubatching
:
bool
,
should_attempt_dp_padding
:
bool
,
cudagraph_mode
:
int
,
cudagraph_mode
:
int
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
)
->
tuple
[
bool
,
torch
.
Tensor
|
None
,
int
]:
)
->
tuple
[
bool
,
torch
.
Tensor
|
None
,
int
]:
...
@@ -113,8 +110,8 @@ def _synchronize_dp_ranks(
...
@@ -113,8 +110,8 @@ def _synchronize_dp_ranks(
run with microbatching or none of them do.
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
2. Determines the total number of tokens that each rank will run.
When running microbatched or if
should_attempt_dp_padding is True, all
When running microbatched or if
cudagraph is enabled (synced across ranks),
ranks will be padded out so that the run with the same number of tokens
all
ranks will be padded out so that the
y
run with the same number of tokens
.
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
...
@@ -133,29 +130,26 @@ def _synchronize_dp_ranks(
...
@@ -133,29 +130,26 @@ def _synchronize_dp_ranks(
# will run and if we are using ubatching or not.
# will run and if we are using ubatching or not.
tensor
=
_run_ar
(
tensor
=
_run_ar
(
should_ubatch
=
should_attempt_ubatching
,
should_ubatch
=
should_attempt_ubatching
,
should_dp_pad
=
should_attempt_dp_padding
,
orig_num_tokens_per_ubatch
=
num_tokens_unpadded
,
orig_num_tokens_per_ubatch
=
num_tokens_unpadded
,
padded_num_tokens_per_ubatch
=
num_tokens_padded
,
padded_num_tokens_per_ubatch
=
num_tokens_padded
,
cudagraph_mode
=
cudagraph_mode
,
cudagraph_mode
=
cudagraph_mode
,
parallel_config
=
parallel_config
,
parallel_config
=
parallel_config
,
)
)
should_dp_pad
=
bool
(
torch
.
all
(
tensor
[
3
]
==
1
).
item
())
# Synchronize cudagraph_mode across ranks first (take min).
# This is needed before DP padding decision since we use the synced
#
DP ranks should all have the same value for should_attempt_dp_padding
.
#
cudagraph mode to determine whether DP padding is needed
.
assert
should_attempt_dp_padding
==
should_dp_pad
synced_cudagraph_mode
=
_post_process_cudagraph_mode
(
tensor
)
# Check conditions for microbatching
# Check conditions for microbatching
should_ubatch
=
_post_process_ubatch
(
tensor
,
parallel_config
.
num_ubatches
)
should_ubatch
=
_post_process_ubatch
(
tensor
,
parallel_config
.
num_ubatches
)
if
should_ubatch
and
not
should_dp_pad
:
# DP padding is needed when cudagraph is enabled (synced across ranks)
logger
.
debug_once
(
# or when ubatching/DBO is active (ubatching requires uniform batch
"Microbatching has been triggered and requires DP padding. "
# sizes across DP ranks currently).
"Enabling DP padding even though it has been explicitly "
# Use the synced runtime cudagraph mode rather than the compilation config
"disabled."
,
# so we can avoid padding when cudagraph is not enabled for this step.
scope
=
"global"
,
should_dp_pad
=
synced_cudagraph_mode
!=
0
or
should_ubatch
)
should_dp_pad
=
True
# Pad all DP ranks up to the maximum token count across ranks if
# Pad all DP ranks up to the maximum token count across ranks if
# should_dp_pad is True
# should_dp_pad is True
...
@@ -164,16 +158,12 @@ def _synchronize_dp_ranks(
...
@@ -164,16 +158,12 @@ def _synchronize_dp_ranks(
should_dp_pad
,
should_dp_pad
,
)
)
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode
=
_post_process_cudagraph_mode
(
tensor
)
return
should_ubatch
,
num_tokens_after_padding
,
synced_cudagraph_mode
return
should_ubatch
,
num_tokens_after_padding
,
synced_cudagraph_mode
def
coordinate_batch_across_dp
(
def
coordinate_batch_across_dp
(
num_tokens_unpadded
:
int
,
num_tokens_unpadded
:
int
,
allow_microbatching
:
bool
,
allow_microbatching
:
bool
,
allow_dp_padding
:
bool
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
num_tokens_padded
:
int
|
None
=
None
,
num_tokens_padded
:
int
|
None
=
None
,
uniform_decode
:
bool
|
None
=
None
,
uniform_decode
:
bool
|
None
=
None
,
...
@@ -187,7 +177,6 @@ def coordinate_batch_across_dp(
...
@@ -187,7 +177,6 @@ def coordinate_batch_across_dp(
Args:
Args:
num_tokens_unpadded: Number of tokens without accounting for padding
num_tokens_unpadded: Number of tokens without accounting for padding
allow_microbatching: If microbatching should be attempted
allow_microbatching: If microbatching should be attempted
allow_dp_padding: If all DP ranks should be padded up to the same value
parallel_config: The parallel config
parallel_config: The parallel config
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
TP, etc)
TP, etc)
...
@@ -195,15 +184,15 @@ def coordinate_batch_across_dp(
...
@@ -195,15 +184,15 @@ def coordinate_batch_across_dp(
only contains single token decodes
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL).
DP padding is enabled when synced cudagraph mode across ranks is not NONE.
Returns: tuple[
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
microbatch
num_tokens_after_padding: A tensor containing the total number of
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
padded up to the max value across all DP ranks when cudagraph is enabled.
is True.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
]
...
@@ -231,7 +220,6 @@ def coordinate_batch_across_dp(
...
@@ -231,7 +220,6 @@ def coordinate_batch_across_dp(
num_tokens_unpadded
,
num_tokens_unpadded
,
num_tokens_padded
,
num_tokens_padded
,
should_attempt_ubatching
,
should_attempt_ubatching
,
allow_dp_padding
,
cudagraph_mode
,
cudagraph_mode
,
parallel_config
,
parallel_config
,
)
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1d532f9d
...
@@ -2300,7 +2300,7 @@ class GPUModelRunner(
...
@@ -2300,7 +2300,7 @@ class GPUModelRunner(
)
)
# Dispatch for the decoder portion of the model.
# Dispatch for the decoder portion of the model.
_
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
_
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_logits
,
disable_full
=
True
num_logits
,
invalid_modes
=
{
CUDAGraphMode
.
FULL
}
)
)
num_logits_padded
=
batch_desc
.
num_tokens
num_logits_padded
=
batch_desc
.
num_tokens
logits_indices_padded
=
self
.
kv_sharing_fast_prefill_logits_indices
[
logits_indices_padded
=
self
.
kv_sharing_fast_prefill_logits_indices
[
...
@@ -3174,20 +3174,19 @@ class GPUModelRunner(
...
@@ -3174,20 +3174,19 @@ class GPUModelRunner(
has_lora
=
num_active_loras
>
0
if
force_has_lora
is
None
else
force_has_lora
has_lora
=
num_active_loras
>
0
if
force_has_lora
is
None
else
force_has_lora
num_tokens_padded
=
self
.
_pad_for_sequence_parallelism
(
num_tokens
)
num_tokens_padded
=
self
.
_pad_for_sequence_parallelism
(
num_tokens
)
dispatch_cudagraph
=
(
lambda
num_tokens
,
disable_full
:
self
.
cudagraph_dispatcher
.
dispatch
(
def
dispatch_cudagraph
(
num_tokens
,
disable_full
=
False
,
valid_modes
=
None
):
return
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
has_lora
=
has_lora
,
has_lora
=
has_lora
,
uniform_decode
=
uniform_decode
,
uniform_decode
=
uniform_decode
,
disable_full
=
disable_full
,
num_active_loras
=
num_active_loras
,
num_active_loras
=
num_active_loras
,
valid_modes
=
{
CUDAGraphMode
.
NONE
}
if
force_eager
else
valid_modes
,
invalid_modes
=
{
CUDAGraphMode
.
FULL
}
if
disable_full
else
None
,
)
)
if
not
force_eager
else
(
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens_padded
))
)
cudagraph_mode
,
batch_descriptor
=
dispatch_cudagraph
(
cudagraph_mode
,
batch_descriptor
=
dispatch_cudagraph
(
num_tokens_padded
,
use_cascade_attn
or
has_encoder_output
num_tokens_padded
,
disable_full
=
use_cascade_attn
or
has_encoder_output
)
)
num_tokens_padded
=
batch_descriptor
.
num_tokens
num_tokens_padded
=
batch_descriptor
.
num_tokens
if
self
.
compilation_config
.
pass_config
.
enable_sp
:
if
self
.
compilation_config
.
pass_config
.
enable_sp
:
...
@@ -3204,20 +3203,11 @@ class GPUModelRunner(
...
@@ -3204,20 +3203,11 @@ class GPUModelRunner(
# across ranks
# across ranks
should_ubatch
,
num_tokens_across_dp
=
False
,
None
should_ubatch
,
num_tokens_across_dp
=
False
,
None
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
# in a P/D setup and still use CUDA graphs (enabled by this padding) on the
# decoder.
allow_dp_padding
=
(
self
.
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
)
should_ubatch
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
should_ubatch
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
coordinate_batch_across_dp
(
coordinate_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_unpadded
=
num_tokens
,
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
allow_microbatching
=
allow_microbatching
,
allow_microbatching
=
allow_microbatching
,
allow_dp_padding
=
allow_dp_padding
,
num_tokens_padded
=
num_tokens_padded
,
num_tokens_padded
=
num_tokens_padded
,
uniform_decode
=
uniform_decode
,
uniform_decode
=
uniform_decode
,
num_scheduled_tokens_per_request
=
num_scheduled_tokens_np
,
num_scheduled_tokens_per_request
=
num_scheduled_tokens_np
,
...
@@ -3232,7 +3222,7 @@ class GPUModelRunner(
...
@@ -3232,7 +3222,7 @@ class GPUModelRunner(
# Re-dispatch with DP padding so we have the correct batch_descriptor
# Re-dispatch with DP padding so we have the correct batch_descriptor
cudagraph_mode
,
batch_descriptor
=
dispatch_cudagraph
(
cudagraph_mode
,
batch_descriptor
=
dispatch_cudagraph
(
num_tokens_padded
,
num_tokens_padded
,
disable_full
=
synced_cudagraph_mode
<=
CUDAGraphMode
.
PIECEWISE
.
value
,
valid_modes
=
{
CUDAGraphMode
(
synced_cudagraph_mode
)}
,
)
)
# Assert to make sure the agreed upon token count is correct otherwise
# Assert to make sure the agreed upon token count is correct otherwise
# num_tokens_across_dp will no-longer be valid
# num_tokens_across_dp will no-longer be valid
...
@@ -4724,7 +4714,7 @@ class GPUModelRunner(
...
@@ -4724,7 +4714,7 @@ class GPUModelRunner(
assert
(
assert
(
cudagraph_runtime_mode
is
None
cudagraph_runtime_mode
is
None
or
cudagraph_runtime_mode
.
valid_runtime_mode
s
()
or
cudagraph_runtime_mode
.
is_
valid_runtime_mode
()
)
)
# If cudagraph_mode.decode_mode() == FULL and
# If cudagraph_mode.decode_mode() == FULL and
...
@@ -5336,7 +5326,7 @@ class GPUModelRunner(
...
@@ -5336,7 +5326,7 @@ class GPUModelRunner(
):
):
assert
(
assert
(
cudagraph_runtime_mode
!=
CUDAGraphMode
.
NONE
cudagraph_runtime_mode
!=
CUDAGraphMode
.
NONE
and
cudagraph_runtime_mode
.
valid_runtime_mode
s
()
and
cudagraph_runtime_mode
.
is_
valid_runtime_mode
()
),
f
"Invalid cudagraph runtime mode:
{
cudagraph_runtime_mode
}
"
),
f
"Invalid cudagraph runtime mode:
{
cudagraph_runtime_mode
}
"
if
not
batch_descriptors
:
if
not
batch_descriptors
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment