Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b34474bf
Unverified
Commit
b34474bf
authored
Jan 15, 2026
by
Wentao Ye
Committed by
GitHub
Jan 15, 2026
Browse files
[Feature] Support async scheduling + PP (#32359)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
6218034d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
19 deletions
+16
-19
tests/v1/core/utils.py
tests/v1/core/utils.py
+3
-0
vllm/config/vllm.py
vllm/config/vllm.py
+1
-13
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+7
-0
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+3
-3
vllm/v1/executor/ray_executor.py
vllm/v1/executor/ray_executor.py
+2
-3
No files found.
tests/v1/core/utils.py
View file @
b34474bf
...
...
@@ -9,6 +9,7 @@ from vllm.config import (
ECTransferConfig
,
KVTransferConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
,
...
...
@@ -53,6 +54,7 @@ def create_scheduler(
num_speculative_tokens
:
int
|
None
=
None
,
skip_tokenizer_init
:
bool
=
False
,
async_scheduling
:
bool
=
False
,
pipeline_parallel_size
:
int
=
1
,
use_ec_connector
:
bool
=
False
,
ec_role
:
str
|
None
=
None
,
)
->
Scheduler
|
AsyncScheduler
:
...
...
@@ -133,6 +135,7 @@ def create_scheduler(
scheduler_config
=
scheduler_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
pipeline_parallel_size
),
kv_transfer_config
=
kv_transfer_config
,
speculative_config
=
speculative_config
,
ec_transfer_config
=
ec_transfer_config
,
...
...
vllm/config/vllm.py
View file @
b34474bf
...
...
@@ -563,11 +563,6 @@ class VllmConfig:
if
self
.
scheduler_config
.
async_scheduling
:
# Async scheduling explicitly enabled, hard fail any incompatibilities.
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
:
raise
ValueError
(
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
# Currently, async scheduling only support eagle speculative
# decoding.
if
self
.
speculative_config
is
not
None
:
...
...
@@ -589,14 +584,7 @@ class VllmConfig:
)
elif
self
.
scheduler_config
.
async_scheduling
is
None
:
# Enable async scheduling unless there is an incompatible option.
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
:
logger
.
warning_once
(
"Async scheduling is not yet supported with "
"pipeline_parallel_size > 1 and will be disabled."
,
scope
=
"local"
,
)
self
.
scheduler_config
.
async_scheduling
=
False
elif
(
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
):
...
...
vllm/v1/core/sched/scheduler.py
View file @
b34474bf
...
...
@@ -283,6 +283,13 @@ class Scheduler(SchedulerInterface):
while
req_index
<
len
(
self
.
running
)
and
token_budget
>
0
:
request
=
self
.
running
[
req_index
]
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
if
self
.
use_pp
and
request
.
num_output_placeholders
>
0
:
req_index
+=
1
continue
if
(
request
.
num_output_placeholders
>
0
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
...
...
vllm/v1/executor/multiproc_executor.py
View file @
b34474bf
...
...
@@ -411,9 +411,9 @@ class MultiprocExecutor(Executor):
@
cached_property
def
max_concurrent_batches
(
self
)
->
int
:
if
self
.
scheduler_config
.
async_schedu
lin
g
:
return
2
return
self
.
parallel_config
.
pipeline_parallel
_size
# PP requires PP-size concurrent batches to fill the pipe
lin
e.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
return
2
if
pp_size
<=
1
and
self
.
scheduler_config
.
async_scheduling
else
pp
_size
def
_get_output_rank
(
self
)
->
int
:
# Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
...
...
vllm/v1/executor/ray_executor.py
View file @
b34474bf
...
...
@@ -111,9 +111,8 @@ class RayDistributedExecutor(Executor):
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
if
self
.
scheduler_config
.
async_scheduling
:
return
2
return
self
.
parallel_config
.
pipeline_parallel_size
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
return
2
if
pp_size
<=
1
and
self
.
scheduler_config
.
async_scheduling
else
pp_size
def
shutdown
(
self
)
->
None
:
if
logger
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment