Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da786e33
Unverified
Commit
da786e33
authored
Nov 07, 2025
by
Nick Hill
Committed by
GitHub
Nov 07, 2025
Browse files
[Core] Rework handling of async scheduling config (#28250)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
18903216
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
121 additions
and
71 deletions
+121
-71
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+17
-15
vllm/config/scheduler.py
vllm/config/scheduler.py
+34
-9
vllm/config/vllm.py
vllm/config/vllm.py
+49
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-26
vllm/v1/core/sched/interface.py
vllm/v1/core/sched/interface.py
+18
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-19
No files found.
tests/v1/engine/test_engine_core.py
View file @
da786e33
...
...
@@ -66,7 +66,7 @@ def test_engine_core():
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
1
...
...
@@ -75,7 +75,7 @@ def test_engine_core():
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
running
)
==
1
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
2
...
...
@@ -85,12 +85,12 @@ def test_engine_core():
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
2
assert
len
(
engine_core
.
scheduler
.
running
)
==
2
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
4
# Loop through until they are all done.
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
while
(
outs
:
=
engine_core
.
step
_fn
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
...
...
@@ -107,7 +107,7 @@ def test_engine_core():
assert
engine_core
.
scheduler
.
has_unfinished_requests
()
assert
not
engine_core
.
scheduler
.
has_finished_requests
()
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
1
assert
engine_core
.
scheduler
.
has_unfinished_requests
()
...
...
@@ -119,7 +119,7 @@ def test_engine_core():
assert
not
engine_core
.
scheduler
.
has_unfinished_requests
()
assert
engine_core
.
scheduler
.
has_finished_requests
()
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
not
engine_core
.
scheduler
.
has_unfinished_requests
()
assert
not
engine_core
.
scheduler
.
has_finished_requests
()
...
...
@@ -133,7 +133,7 @@ def test_engine_core():
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
2
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
2
...
...
@@ -141,7 +141,7 @@ def test_engine_core():
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
running
)
==
2
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
3
...
...
@@ -150,7 +150,7 @@ def test_engine_core():
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
2
_
=
engine_core
.
step
()
_
=
engine_core
.
step
_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
2
...
...
@@ -165,12 +165,12 @@ def test_engine_core():
req0
.
request_id
=
req1
.
request_id
=
"test"
engine_core
.
add_request
(
*
engine_core
.
preprocess_add_request
(
req0
))
while
(
outs
:
=
engine_core
.
s
tep
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
while
engine_core
.
s
cheduler
.
has_requests
()
:
engine_core
.
step_fn
()
engine_core
.
add_request
(
*
engine_core
.
preprocess_add_request
(
req1
))
while
(
outs
:
=
engine_core
.
s
tep
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
while
engine_core
.
s
cheduler
.
has_requests
()
:
engine_core
.
step_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
...
...
@@ -208,8 +208,8 @@ def test_engine_core_advanced_sampling():
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
# Loop through until they are all done.
while
(
outs
:
=
engine_core
.
s
tep
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
while
engine_core
.
s
cheduler
.
has_requests
()
:
engine_core
.
step_fn
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
...
...
@@ -297,6 +297,8 @@ def test_engine_core_concurrent_batches():
max_num_batched_tokens
=
10
,
# Reduce startup time.
enforce_eager
=
True
,
# Test concurrent batch behaviour independently of async scheduling.
async_scheduling
=
False
,
)
vllm_config
=
engine_args
.
create_engine_config
()
with
set_default_torch_num_threads
(
1
):
...
...
vllm/config/scheduler.py
View file @
da786e33
...
...
@@ -4,7 +4,7 @@
import
hashlib
from
collections.abc
import
Callable
from
dataclasses
import
InitVar
from
typing
import
Any
,
Literal
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
cast
from
pydantic
import
Field
,
field_validator
,
model_validator
from
pydantic.dataclasses
import
dataclass
...
...
@@ -17,6 +17,10 @@ from vllm.utils import (
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
,
)
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.interface
import
SchedulerInterface
logger
=
init_logger
(
__name__
)
...
...
@@ -120,7 +124,7 @@ class SchedulerConfig:
# scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
# (default) or "mod.custom_class".
scheduler_cls
:
str
|
type
[
object
]
=
"vllm.v1.core.sched.scheduler.Scheduler"
scheduler_cls
:
str
|
type
[
object
]
=
Field
(
default
=
None
)
"""The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class"."""
...
...
@@ -132,12 +136,34 @@ class SchedulerConfig:
"""
async_scheduling
:
bool
=
False
"""
EXPERIMENTAL:
If set to True, perform async scheduling. This
may
help
reduce the CPU overheads
, leading to better latency and throughput.
However,
a
sync scheduling is currently not supported with some features such as
structured outputs,
speculative decoding
,
and pipeline parallelism.
"""If set to True, perform async scheduling. This help
s to avoid gaps in
GPU utilization
, leading to better latency and throughput.
A
sync scheduling is currently not supported with some features such as
speculative decoding and pipeline parallelism.
"""
def
get_scheduler_cls
(
self
)
->
type
[
"SchedulerInterface"
]:
if
self
.
scheduler_cls
is
None
:
if
self
.
async_scheduling
:
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
return
AsyncScheduler
from
vllm.v1.core.sched.scheduler
import
Scheduler
return
Scheduler
# This warning can be removed once the Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
logger
.
warning_once
(
"Using custom scheduler class %s. This scheduler interface is "
"not public and compatibility may not be maintained."
,
self
.
scheduler_cls
,
)
if
not
isinstance
(
self
.
scheduler_cls
,
str
):
return
cast
(
type
[
"SchedulerInterface"
],
self
.
scheduler_cls
)
return
resolve_obj_by_qualname
(
self
.
scheduler_cls
)
def
compute_hash
(
self
)
->
str
:
"""
WARNING: Whenever a new field is added to this config,
...
...
@@ -161,6 +187,8 @@ class SchedulerConfig:
"max_num_seqs"
,
"max_model_len"
,
"enable_chunked_prefill"
,
"scheduler_cls"
,
"async_scheduling"
,
mode
=
"wrap"
,
)
@
classmethod
...
...
@@ -242,9 +270,6 @@ class SchedulerConfig:
self
.
long_prefill_token_threshold
,
)
if
self
.
async_scheduling
:
self
.
scheduler_cls
=
"vllm.v1.core.sched.async_scheduler.AsyncScheduler"
@
model_validator
(
mode
=
"after"
)
def
_verify_args
(
self
)
->
Self
:
if
(
...
...
vllm/config/vllm.py
View file @
da786e33
...
...
@@ -353,6 +353,53 @@ class VllmConfig:
self
.
model_config
,
self
.
load_config
)
executor_backend
=
self
.
parallel_config
.
distributed_executor_backend
executor_supports_async_sched
=
executor_backend
in
(
"mp"
,
"uni"
,
"external_launcher"
,
)
if
self
.
scheduler_config
.
async_scheduling
:
# Async scheduling explicitly enabled, hard fail any incompatibilities.
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
:
raise
ValueError
(
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
if
self
.
speculative_config
is
not
None
:
raise
ValueError
(
"Async scheduling is not yet compatible with speculative decoding."
)
if
not
executor_supports_async_sched
:
raise
ValueError
(
"Currently, async scheduling only supports `mp`, `uni`, or "
"`external_launcher` distributed executor backend, but you chose "
f
"`
{
executor_backend
}
`."
)
elif
self
.
scheduler_config
.
async_scheduling
is
None
:
# Enable async scheduling unless there is an incompatible option.
# NOTE: we won't reach here until async scheduling is enabled by default.
if
(
self
.
parallel_config
.
pipeline_parallel_size
>
1
or
self
.
speculative_config
is
not
None
):
logger
.
warning
(
"Async scheduling is not yet supported with speculative decoding "
" or pipeline_parallel_size > 1 and will be disabled."
)
self
.
scheduler_config
.
async_scheduling
=
False
elif
not
executor_supports_async_sched
:
logger
.
warning
(
"Async scheduling will be disabled because it is not supported "
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
"`external_launcher` are supported)."
,
executor_backend
,
)
self
.
scheduler_config
.
async_scheduling
=
False
else
:
self
.
scheduler_config
.
async_scheduling
=
True
from
vllm.platforms
import
current_platform
if
(
...
...
@@ -467,7 +514,7 @@ class VllmConfig:
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
use_eagle
()
):
raise
NotImplemented
Error
(
raise
Value
Error
(
"Fast prefill optimization for KV sharing is not "
"compatible with EAGLE as EAGLE requires correct logits "
"for all tokens while fast prefill gives incorrect logits "
...
...
@@ -491,7 +538,7 @@ class VllmConfig:
)
if
not
getattr
(
self
.
model_config
.
hf_config
,
"is_causal"
,
True
):
disable_chunked_prefill_reasons
.
append
(
"Only models using causal attention support
s
chunked "
"Only models using causal attention support chunked "
"prefill and prefix caching; disabling both."
)
elif
self
.
model_config
.
is_encoder_decoder
:
...
...
vllm/engine/arg_utils.py
View file @
da786e33
...
...
@@ -513,7 +513,7 @@ class EngineArgs:
ObservabilityConfig
.
collect_detailed_traces
)
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduler_cls
:
str
|
type
[
object
]
=
SchedulerConfig
.
scheduler_cls
scheduler_cls
:
str
|
type
[
object
]
|
None
=
SchedulerConfig
.
scheduler_cls
pooler_config
:
PoolerConfig
|
None
=
ModelConfig
.
pooler_config
override_pooler_config
:
dict
|
PoolerConfig
|
None
=
(
...
...
@@ -552,7 +552,7 @@ class EngineArgs:
)
"""Custom logitproc types"""
async_scheduling
:
bool
=
SchedulerConfig
.
async_scheduling
async_scheduling
:
bool
|
None
=
SchedulerConfig
.
async_scheduling
kv_sharing_fast_prefill
:
bool
=
CacheConfig
.
kv_sharing_fast_prefill
...
...
@@ -1479,20 +1479,6 @@ class EngineArgs:
else
ParallelConfig
.
data_parallel_rpc_port
)
if
self
.
async_scheduling
:
if
self
.
pipeline_parallel_size
>
1
:
raise
ValueError
(
"Async scheduling is not supported with pipeline-parallel-size > 1."
)
# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if
self
.
speculative_config
is
not
None
:
raise
ValueError
(
"Currently, speculative decoding is not supported with "
"async scheduling."
)
# Forward the deprecated CLI args to the EPLB config.
if
self
.
num_redundant_experts
is
not
None
:
self
.
eplb_config
.
num_redundant_experts
=
self
.
num_redundant_experts
...
...
@@ -1536,16 +1522,6 @@ class EngineArgs:
_api_process_rank
=
self
.
_api_process_rank
,
)
if
self
.
async_scheduling
and
(
parallel_config
.
distributed_executor_backend
not
in
(
"mp"
,
"uni"
,
"external_launcher"
)
):
raise
ValueError
(
"Currently, async scheduling only supports `mp`, `uni` or "
"`external_launcher` distributed executor backend, but you choose "
f
"`
{
parallel_config
.
distributed_executor_backend
}
`."
)
speculative_config
=
self
.
create_speculative_config
(
target_model_config
=
model_config
,
target_parallel_config
=
parallel_config
,
...
...
vllm/v1/core/sched/interface.py
View file @
da786e33
...
...
@@ -4,16 +4,34 @@ from abc import ABC, abstractmethod
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.engine
import
EngineCoreOutputs
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.structured_output
import
StructuredOutputManager
class
SchedulerInterface
(
ABC
):
@
abstractmethod
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
kv_cache_config
:
"KVCacheConfig"
,
structured_output_manager
:
"StructuredOutputManager"
,
block_size
:
int
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
include_finished_set
:
bool
=
False
,
log_stats
:
bool
=
False
,
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
schedule
(
self
)
->
"SchedulerOutput"
:
"""Schedule the requests to process in this scheduling step.
...
...
vllm/v1/engine/core.py
View file @
da786e33
...
...
@@ -29,7 +29,6 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
from
vllm.transformers_utils.config
import
maybe_register_config_serialize_by_value
from
vllm.utils.gc_utils
import
maybe_attach_gc_debug_callback
from
vllm.utils.hashing
import
get_hash_fn_by_name
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.network_utils
import
make_zmq_socket
from
vllm.utils.system_utils
import
decorate_logs
,
set_process_title
from
vllm.v1.core.kv_cache_utils
import
(
...
...
@@ -41,7 +40,6 @@ from vllm.v1.core.kv_cache_utils import (
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
...
...
@@ -117,23 +115,7 @@ class EngineCore:
self
.
structured_output_manager
=
StructuredOutputManager
(
vllm_config
)
# Setup scheduler.
if
isinstance
(
vllm_config
.
scheduler_config
.
scheduler_cls
,
str
):
Scheduler
=
resolve_obj_by_qualname
(
vllm_config
.
scheduler_config
.
scheduler_cls
)
else
:
Scheduler
=
vllm_config
.
scheduler_config
.
scheduler_cls
# This warning can be removed once the V1 Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
if
Scheduler
is
not
V1Scheduler
:
logger
.
warning
(
"Using configured V1 scheduler class %s. "
"This scheduler interface is not public and "
"compatibility may not be maintained."
,
vllm_config
.
scheduler_config
.
scheduler_cls
,
)
Scheduler
=
vllm_config
.
scheduler_config
.
get_scheduler_cls
()
if
len
(
kv_cache_config
.
kv_cache_groups
)
==
0
:
# Encoder models without KV cache don't support
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment