Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0cdbe7b7
Unverified
Commit
0cdbe7b7
authored
Oct 31, 2025
by
Nick Hill
Committed by
GitHub
Nov 01, 2025
Browse files
[Core] Async scheduling + structured outputs compatibility (#26866)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
df334868
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
278 additions
and
154 deletions
+278
-154
tests/conftest.py
tests/conftest.py
+3
-0
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+0
-9
tests/v1/e2e/test_async_scheduling.py
tests/v1/e2e/test_async_scheduling.py
+12
-2
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+18
-1
tests/v1/executor/test_executor.py
tests/v1/executor/test_executor.py
+3
-1
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
+0
-2
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+1
-3
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+0
-12
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+0
-12
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+18
-12
vllm/v1/core/sched/async_scheduler.py
vllm/v1/core/sched/async_scheduler.py
+8
-0
vllm/v1/core/sched/interface.py
vllm/v1/core/sched/interface.py
+7
-1
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+11
-6
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+13
-18
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+57
-14
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+26
-10
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+26
-17
vllm/v1/executor/ray_executor.py
vllm/v1/executor/ray_executor.py
+33
-4
vllm/v1/executor/ray_utils.py
vllm/v1/executor/ray_utils.py
+20
-15
vllm/v1/structured_output/utils.py
vllm/v1/structured_output/utils.py
+22
-15
No files found.
tests/conftest.py
View file @
0cdbe7b7
...
...
@@ -6,6 +6,9 @@ from copy import deepcopy
from
tblib
import
pickling_support
# Import fixture
from
tests.v1.entrypoints.conftest
import
sample_json_schema
# noqa
# ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate
...
...
tests/v1/core/test_scheduler.py
View file @
0cdbe7b7
...
...
@@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_output
=
ModelRunnerOutput
(
...
...
@@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_output
=
ModelRunnerOutput
(
...
...
@@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_output
=
ModelRunnerOutput
(
...
...
@@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_output
=
ModelRunnerOutput
(
...
...
@@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
scheduler
.
add_request
(
request
)
output
=
scheduler
.
schedule
()
assert
len
(
output
.
scheduled_new_reqs
)
==
len
(
requests
)
assert
output
.
grammar_bitmask
is
None
def
test_schedule_skip_tokenizer_init_structured_output_request
():
...
...
tests/v1/e2e/test_async_sched_and_preempt.py
→
tests/v1/e2e/test_async_scheduling.py
View file @
0cdbe7b7
...
...
@@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
from
vllm
import
SamplingParams
from
vllm.logprobs
import
Logprob
from
vllm.sampling_params
import
StructuredOutputsParams
from
...conftest
import
VllmRunner
from
...models.utils
import
check_outputs_equal
...
...
@@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
@
dynamo_config
.
patch
(
cache_size_limit
=
16
)
def
test_preempt_and_async_scheduling_e2e
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_preempt_and_async_scheduling_e2e
(
sample_json_schema
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters."""
uni/multiproc executor, and various sampling parameters
including structured outputs."""
first_prompt
=
(
"The following numbers of the sequence "
...
...
@@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
dict
(
bad_words
=
[
"the"
,
" the"
]),
dict
(
logprobs
=
2
),
dict
(
logprobs
=
2
,
presence_penalty
=-
1.0
),
dict
(
structured_outputs
=
StructuredOutputsParams
(
json
=
sample_json_schema
)),
dict
(
structured_outputs
=
StructuredOutputsParams
(
json
=
sample_json_schema
),
logprobs
=
2
,
presence_penalty
=-
1.0
,
),
]
default_params
=
dict
(
...
...
tests/v1/engine/test_engine_core.py
View file @
0cdbe7b7
...
...
@@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
self
,
scheduler_output
,
non_block
=
False
,
)
->
Future
[
ModelRunnerOutput
]:
)
->
Future
[
ModelRunnerOutput
|
None
]:
"""Make execute_model non-blocking."""
# DummyExecutor used only for testing async case.
...
...
@@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
# Use the thread pool instead of creating a new thread
return
self
.
thread_pool
.
submit
(
_execute
)
def
sample_tokens
(
self
,
grammar_output
,
non_block
=
False
)
->
Future
[
ModelRunnerOutput
]:
"""Make sample_tokens non-blocking."""
# DummyExecutor used only for testing async case.
assert
non_block
def
_execute
():
output
=
self
.
collective_rpc
(
"sample_tokens"
,
args
=
(
grammar_output
,))
# Make a copy because output[0] may be reused
# by the next batch.
return
copy
.
deepcopy
(
output
[
0
])
# Use the thread pool instead of creating a new thread
return
self
.
thread_pool
.
submit
(
_execute
)
@
property
def
max_concurrent_batches
(
self
)
->
int
:
return
2
...
...
tests/v1/executor/test_executor.py
View file @
0cdbe7b7
...
...
@@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
# Drop marker to show that this was run
with
open
(
".marker"
,
"w"
):
...
return
super
().
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
return
super
().
collective_rpc
(
method
,
timeout
,
args
,
kwargs
,
non_block
,
unique_reply_rank
)
CustomMultiprocExecutorAsync
=
CustomMultiprocExecutor
...
...
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
View file @
0cdbe7b7
...
...
@@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
kv_connector_metadata
=
SharedStorageConnectorMetadata
(),
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
0cdbe7b7
...
...
@@ -981,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
[
0
],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
set
(),
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
free_encoder_mm_hashes
=
[],
)
engine_core_outputs
=
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
0cdbe7b7
...
...
@@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
...
...
@@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks
=
[],
finished_req_ids
=
{
req_id
},
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
...
...
@@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
...
...
@@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
...
...
@@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
...
...
@@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
0cdbe7b7
...
...
@@ -152,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
...
...
@@ -269,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
num_common_prefix_blocks
=
[],
finished_req_ids
=
{
req_id
},
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
...
...
@@ -301,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
...
...
@@ -330,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
...
...
@@ -423,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
...
...
@@ -460,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
structured_output_request_ids
=
[],
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
_update_states
(
scheduler_output
)
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
0cdbe7b7
...
...
@@ -6,7 +6,7 @@ KV cache helper for store.
from
collections.abc
import
Sequence
from
concurrent.futures
import
CancelledError
,
Future
from
typing
import
TYPE_CHECKING
,
Literal
,
cast
from
typing
import
TYPE_CHECKING
,
Literal
import
torch
...
...
@@ -138,8 +138,11 @@ class KVOutputAggregator:
return
cls
(
connector
.
get_finished_count
()
or
world_size
)
def
aggregate
(
self
,
outputs
:
list
[
ModelRunnerOutput
],
output_rank
:
int
=
0
)
->
ModelRunnerOutput
:
self
,
outputs
:
list
[
ModelRunnerOutput
|
None
],
output_rank
:
int
=
0
)
->
ModelRunnerOutput
|
None
:
if
not
outputs
[
output_rank
]:
return
None
# Aggregate kv_connector_output from all workers
def
update_finished_set
(
...
...
@@ -161,6 +164,7 @@ class KVOutputAggregator:
aggregated_kv_connector_stats
=
None
invalid_block_ids
=
set
[
int
]()
for
model_runner_output
in
outputs
:
assert
model_runner_output
is
not
None
kv_output
=
model_runner_output
.
kv_connector_output
if
not
kv_output
:
continue
...
...
@@ -204,6 +208,7 @@ class KVOutputAggregator:
# select output of the worker specified by output_rank
output
=
outputs
[
output_rank
]
assert
output
is
not
None
output
.
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
finished_sending
or
None
,
finished_recving
=
finished_recving
or
None
,
...
...
@@ -215,13 +220,16 @@ class KVOutputAggregator:
return
output
def
async_aggregate
(
self
,
output_futures
:
Sequence
[
Future
[
ModelRunnerOutput
]],
output_rank
:
int
=
0
)
->
Future
[
ModelRunnerOutput
]:
self
,
output_futures
:
Sequence
[
Future
[
ModelRunnerOutput
|
None
]],
output_rank
:
int
=
0
,
)
->
Future
[
ModelRunnerOutput
|
None
]:
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future
:
Future
[
ModelRunnerOutput
]
=
Future
()
result_future
:
Future
[
ModelRunnerOutput
|
None
]
=
Future
()
outputs
:
list
[
ModelRunnerOutput
|
None
]
=
[
None
]
*
len
(
output_futures
)
remaining
=
len
(
output_futures
)
def
make_callback
(
idx
):
def
callback
(
fut
):
...
...
@@ -236,12 +244,10 @@ class KVOutputAggregator:
result_future
.
set_exception
(
e
)
# this check assumes io_thread_pool uses a single thread
if
all
(
outputs
):
result_future
.
set_result
(
self
.
aggregate
(
cast
(
list
[
ModelRunnerOutput
],
outputs
),
output_rank
)
)
nonlocal
remaining
remaining
-=
1
if
not
remaining
:
result_future
.
set_result
(
self
.
aggregate
(
outputs
,
output_rank
))
return
callback
...
...
vllm/v1/core/sched/async_scheduler.py
View file @
0cdbe7b7
...
...
@@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
scheduler_output
:
SchedulerOutput
,
)
->
None
:
super
().
_update_after_schedule
(
scheduler_output
)
pending_structured_output_tokens
=
False
for
req_id
in
scheduler_output
.
num_scheduled_tokens
:
request
=
self
.
requests
[
req_id
]
pending_structured_output_tokens
|=
(
request
.
use_structured_output
and
request
.
num_output_placeholders
>
0
)
if
(
request
.
num_computed_tokens
==
request
.
num_tokens
+
request
.
num_output_placeholders
...
...
@@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
# TODO(woosuk): Support speculative decoding.
request
.
num_output_placeholders
+=
1
scheduler_output
.
pending_structured_output_tokens
=
(
pending_structured_output_tokens
)
def
_update_request_with_output
(
self
,
request
:
Request
,
...
...
vllm/v1/core/sched/interface.py
View file @
0cdbe7b7
...
...
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.engine
import
EngineCoreOutputs
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
...
...
@@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
"""
raise
NotImplementedError
@
abstractmethod
def
get_grammar_bitmask
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
"GrammarOutput | None"
:
raise
NotImplementedError
@
abstractmethod
def
update_from_output
(
self
,
...
...
vllm/v1/core/sched/output.py
View file @
0cdbe7b7
...
...
@@ -181,12 +181,17 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes
:
list
[
str
]
# ids of structured outputs requests included in the bitmask, in the
# same order as the corresponding stacked rows of the bitmask.
# There may be more than one row per request in the case of speculative decoding.
structured_output_request_ids
:
list
[
str
]
# the bitmask for the whole batch
grammar_bitmask
:
"npt.NDArray[np.int32] | None"
# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens
:
bool
=
False
# KV Cache Connector metadata.
kv_connector_metadata
:
KVConnectorMetadata
|
None
=
None
@
dataclass
class
GrammarOutput
:
# ids of structured output requests.
structured_output_request_ids
:
list
[
str
]
# Bitmask ordered as structured_output_request_ids.
grammar_bitmask
:
"npt.NDArray[np.int32]"
vllm/v1/core/sched/scheduler.py
View file @
0cdbe7b7
...
...
@@ -5,7 +5,7 @@ import itertools
import
time
from
collections
import
defaultdict
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
Any
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
EventPublisherFactory
,
KVEventBatch
...
...
@@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
,
KVCacheManager
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
CachedRequestData
,
NewRequestData
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
GrammarOutput
,
NewRequestData
,
SchedulerOutput
,
)
from
vllm.v1.core.sched.request_queue
import
SchedulingPolicy
,
create_request_queue
from
vllm.v1.core.sched.utils
import
check_stop
,
remove_all
from
vllm.v1.engine
import
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
...
...
@@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
if
TYPE_CHECKING
:
import
numpy
as
np
import
numpy.typing
as
npt
logger
=
init_logger
(
__name__
)
...
...
@@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens
,
req_to_new_blocks
,
)
structured_output_request_ids
,
grammar_bitmask
=
self
.
get_grammar_bitmask
(
num_scheduled_tokens
.
keys
(),
scheduled_spec_decode_tokens
)
# Record the request ids that were scheduled in this step.
self
.
prev_step_scheduled_req_ids
.
clear
()
...
...
@@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids
=
self
.
finished_req_ids
,
free_encoder_mm_hashes
=
self
.
encoder_cache_manager
.
get_freed_mm_hashes
(),
structured_output_request_ids
=
structured_output_request_ids
,
grammar_bitmask
=
grammar_bitmask
,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
...
...
@@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
def
get_grammar_bitmask
(
self
,
scheduled_request_ids
:
Iterable
[
str
],
scheduled_spec_decode_tokens
:
dict
[
str
,
list
[
int
]],
)
->
tuple
[
list
[
str
],
"npt.NDArray[np.int32] | None"
]:
scheduler_output
:
SchedulerOutput
,
)
->
GrammarOutput
|
None
:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
...
...
@@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids
=
[
req_id
for
req_id
in
schedule
d_request_id
s
for
req_id
in
schedule
r_output
.
num_scheduled_token
s
if
(
req
:
=
self
.
requests
.
get
(
req_id
))
and
req
.
use_structured_output
]
if
not
structured_output_request_ids
:
return
structured_output_request_ids
,
None
return
None
bitmask
=
self
.
structured_output_manager
.
grammar_bitmask
(
self
.
requests
,
structured_output_request_ids
,
scheduled_spec_decode_tokens
,
scheduler_output
.
scheduled_spec_decode_tokens
,
)
return
structured_output_request_ids
,
bitmask
return
GrammarOutput
(
structured_output_request_ids
,
bitmask
)
def
update_from_output
(
self
,
...
...
vllm/v1/engine/core.py
View file @
0cdbe7b7
...
...
@@ -12,7 +12,7 @@ from concurrent.futures import Future
from
contextlib
import
ExitStack
,
contextmanager
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
typing
import
Any
,
TypeVar
from
typing
import
Any
,
TypeVar
,
cast
import
msgspec
import
zmq
...
...
@@ -334,9 +334,12 @@ class EngineCore:
if
not
self
.
scheduler
.
has_requests
():
return
{},
False
scheduler_output
=
self
.
scheduler
.
schedule
()
future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
with
self
.
log_error_detail
(
scheduler_output
):
model_output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
model_output
=
future
.
result
()
if
model_output
is
None
:
model_output
=
self
.
model_executor
.
sample_tokens
(
grammar_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
...
...
@@ -376,20 +379,47 @@ class EngineCore:
assert
len
(
batch_queue
)
<
self
.
batch_queue_size
model_executed
=
False
deferred_scheduler_output
=
None
if
self
.
scheduler
.
has_requests
():
scheduler_output
=
self
.
scheduler
.
schedule
()
future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
batch_queue
.
appendleft
((
future
,
scheduler_output
))
exec_
future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
model_executed
=
scheduler_output
.
total_num_scheduled_tokens
>
0
if
(
model_executed
and
len
(
batch_queue
)
<
self
.
batch_queue_size
and
not
batch_queue
[
-
1
][
0
].
done
()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return
None
,
True
if
scheduler_output
.
pending_structured_output_tokens
:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output
=
scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with
self
.
log_error_detail
(
scheduler_output
):
exec_result
=
exec_future
.
result
()
assert
exec_result
is
None
else
:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
with
self
.
log_error_detail
(
scheduler_output
):
exec_result
=
exec_future
.
result
()
if
exec_result
is
None
:
# Call sample tokens.
future
=
self
.
model_executor
.
sample_tokens
(
grammar_output
,
non_block
=
True
)
else
:
# No sampling required (e.g. all requests finished).
future
=
cast
(
Future
[
ModelRunnerOutput
],
exec_future
)
# Add this step's future to the queue.
batch_queue
.
appendleft
((
future
,
scheduler_output
))
if
(
model_executed
and
len
(
batch_queue
)
<
self
.
batch_queue_size
and
not
batch_queue
[
-
1
][
0
].
done
()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return
None
,
True
elif
not
batch_queue
:
# Queue is empty. We should not reach here since this method should
...
...
@@ -405,6 +435,19 @@ class EngineCore:
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if
deferred_scheduler_output
:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
deferred_scheduler_output
)
future
=
self
.
model_executor
.
sample_tokens
(
grammar_output
,
non_block
=
True
)
batch_queue
.
appendleft
((
future
,
deferred_scheduler_output
))
return
engine_core_outputs
,
model_executed
def
shutdown
(
self
):
...
...
vllm/v1/executor/abstract.py
View file @
0cdbe7b7
...
...
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.tasks
import
SupportedTask
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
...
...
@@ -187,28 +187,44 @@ class Executor(ABC):
@
overload
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
Literal
[
False
]
=
False
,
)
->
ModelRunnerOutput
:
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
Literal
[
False
]
=
False
)
->
ModelRunnerOutput
|
None
:
pass
@
overload
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
Literal
[
True
]
=
True
,
)
->
Future
[
ModelRunnerOutput
]:
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
Literal
[
True
]
=
True
)
->
Future
[
ModelRunnerOutput
|
None
]:
pass
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
output
=
self
.
collective_rpc
(
# type: ignore[call-overload]
"execute_model"
,
args
=
(
scheduler_output
,),
non_block
=
non_block
)
return
output
[
0
]
@
overload
def
sample_tokens
(
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
Literal
[
False
]
=
False
)
->
ModelRunnerOutput
:
pass
@
overload
def
sample_tokens
(
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
Literal
[
True
]
=
True
)
->
Future
[
ModelRunnerOutput
]:
pass
def
sample_tokens
(
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
output
=
self
.
collective_rpc
(
# type: ignore[call-overload]
"sample_tokens"
,
args
=
(
grammar_output
,),
non_block
=
non_block
)
return
output
[
0
]
def
execute_dummy_batch
(
self
)
->
None
:
self
.
collective_rpc
(
"execute_dummy_batch"
)
...
...
vllm/v1/executor/multiproc_executor.py
View file @
0cdbe7b7
...
...
@@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
get_mp_context
,
set_process_title
,
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.executor.abstract
import
Executor
,
FailureCallback
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
,
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
uw
.
death_writer
.
close
()
self
.
_ensure_worker_termination
([
uw
.
proc
for
uw
in
unready_workers
])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
if
self
.
max_concurrent_batches
>
1
:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self
.
io_thread_pool
=
ThreadPoolExecutor
(
max_workers
=
1
,
thread_name_prefix
=
"mp_exec_io"
)
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue.
# _async_aggregate_workers_output also assumes a single IO thread.
self
.
io_thread_pool
=
ThreadPoolExecutor
(
max_workers
=
1
,
thread_name_prefix
=
"mp_exec_io"
)
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
...
...
@@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
self
.
failure_callback
=
callback
def
execute_model
(
# type: ignore[override]
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
,
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
return
self
.
_execute_with_aggregation
(
"execute_model"
,
scheduler_output
,
non_block
=
non_block
)
def
sample_tokens
(
# type: ignore[override]
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
return
self
.
_execute_with_aggregation
(
# type: ignore[return-value]
"sample_tokens"
,
grammar_output
,
non_block
=
non_block
)
def
_execute_with_aggregation
(
self
,
method
:
str
,
*
args
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
if
not
self
.
has_connector
:
# get output only from a single worker (output_rank)
(
output
,)
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,)
,
method
,
args
=
args
,
unique_reply_rank
=
self
.
output_rank
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
,
...
...
@@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
# get output from all workers
outputs
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,)
,
method
,
args
=
args
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
,
)
...
...
vllm/v1/executor/ray_executor.py
View file @
0cdbe7b7
...
...
@@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
get_ip
,
get_open_port
,
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.ray_utils
import
(
...
...
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
COMPLETED_NONE_FUTURE
:
Future
[
ModelRunnerOutput
|
None
]
=
Future
()
COMPLETED_NONE_FUTURE
.
set_result
(
None
)
@
dataclass
class
RayWorkerMetaData
:
...
...
@@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
# KV connector setup
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
self
.
scheduler_output
:
SchedulerOutput
|
None
=
None
@
property
def
max_concurrent_batches
(
self
)
->
int
:
"""Ray distributed executor supports pipeline parallelism,
...
...
@@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
self
.
shutdown
()
def
execute_model
(
# type: ignore[override]
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
,
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
if
self
.
scheduler_output
is
not
None
:
raise
RuntimeError
(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
self
.
scheduler_output
=
scheduler_output
return
COMPLETED_NONE_FUTURE
if
non_block
else
None
def
sample_tokens
(
# type: ignore[override]
self
,
grammar_output
:
"GrammarOutput | None"
,
non_block
:
bool
=
False
,
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
"""Execute the model on the Ray workers.
The scheduler output to use should have been provided in
a prior call to execute_model().
Args:
scheduler_output: The scheduler output to execute.
grammar_output: The structured outputs grammar bitmask, if applicable.
non_block: If True, the method will return a Future.
Returns:
The model runner output.
"""
scheduler_output
=
self
.
scheduler_output
if
scheduler_output
is
None
:
return
None
# noqa
self
.
scheduler_output
=
None
# Build the compiled DAG for the first time.
if
self
.
forward_dag
is
None
:
# type: ignore
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
refs
=
self
.
forward_dag
.
execute
(
scheduler_output
)
# type: ignore
refs
=
self
.
forward_dag
.
execute
(
(
scheduler_output
,
grammar_output
)
)
# type: ignore
if
not
self
.
has_connector
:
# Get output only from a single worker (output_rank)
...
...
vllm/v1/executor/ray_utils.py
View file @
0cdbe7b7
...
...
@@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
logger
=
init_logger
(
__name__
)
...
...
@@ -82,36 +82,41 @@ try:
def execute_model_ray(
    self,
    execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
    | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
) -> Union[
    "ModelRunnerOutput",
    tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
]:
    """Run one model step on this worker for the Ray Compiled Graph.

    Args:
        execute_model_input: Either ``(scheduler_output, grammar_output)``
            or ``(scheduler_output, grammar_output, intermediate_tensors)``.
            With the two-element form, no intermediate tensors are passed
            to the model runner.

    Returns:
        Either the model runner's output, or a
        ``(scheduler_output, grammar_output, tensors)`` tuple. The third
        element is the ``IntermediateTensors`` produced by this worker, or
        ``None`` when this is not the last pipeline-parallel rank and the
        runner produced no tensors.
        NOTE(review): the tuple is presumably consumed as the input of the
        next stage in the compiled graph -- confirm against the DAG
        construction code.
    """
    # This method is used by Ray Compiled Graph to execute the model,
    # and it needs a special logic of self.setup_device_if_necessary()
    self.setup_device_if_necessary()
    assert self.worker is not None, "Worker is not initialized"

    # The input arity tells us whether intermediate tensors came along.
    if len(execute_model_input) == 3:
        scheduler_output, grammar_output, intermediate_tensors = (
            execute_model_input
        )
    else:
        scheduler_output, grammar_output = execute_model_input
        intermediate_tensors = None

    assert self.worker.model_runner is not None
    output = self.worker.model_runner.execute_model(
        scheduler_output, intermediate_tensors
    )
    if isinstance(output, IntermediateTensors):
        # Bundle the tensors with the scheduler and grammar outputs.
        output = scheduler_output, grammar_output, output
    elif not get_pp_group().is_last_rank:
        # Case where there are no scheduled requests
        # but may still be finished requests.
        assert not output or not output.req_ids
        output = scheduler_output, grammar_output, None
    elif output is None:
        # execute_model() returned nothing, so run the sampling step with
        # the grammar output.
        output = self.worker.model_runner.sample_tokens(grammar_output)
        # Ensure outputs crossing Ray compiled DAG are serializable.
        # AsyncModelRunnerOutput holds CUDA events and cannot be
        # pickled.
        if isinstance(output, AsyncModelRunnerOutput):
            output = output.get_output()
    return output
def
override_env_vars
(
self
,
vars
:
dict
[
str
,
str
]):
...
...
vllm/v1/structured_output/utils.py
View file @
0cdbe7b7
...
...
@@ -16,6 +16,7 @@ from diskcache import Cache
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.utils.import_utils
import
LazyLoader
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
if
TYPE_CHECKING
:
import
outlines_core
as
oc
...
...
@@ -24,7 +25,6 @@ if TYPE_CHECKING:
import
xgrammar
as
xgr
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
else
:
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
...
...
@@ -47,6 +47,7 @@ CACHE = None
def
apply_grammar_bitmask
(
scheduler_output
:
SchedulerOutput
,
grammar_output
:
GrammarOutput
,
input_batch
:
InputBatch
,
logits
:
torch
.
Tensor
,
)
->
None
:
...
...
@@ -58,9 +59,9 @@ def apply_grammar_bitmask(
input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward.
"""
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
if
grammar_bitmask
is
None
:
return
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask
=
grammar_output
.
grammar_bitmask
# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
...
...
@@ -79,7 +80,7 @@ def apply_grammar_bitmask(
cumulative_offset
+=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
)
if
req_id
in
schedule
r_output
.
structured_output_request_ids
:
if
req_id
in
gramma
r_output
.
structured_output_request_ids
:
struct_out_req_batch_indices
[
req_id
]
=
logit_index
out_indices
=
[]
...
...
@@ -91,7 +92,7 @@ def apply_grammar_bitmask(
dtype
=
grammar_bitmask
.
dtype
,
)
cumulative_index
=
0
for
req_id
in
schedule
r_output
.
structured_output_request_ids
:
for
req_id
in
gramma
r_output
.
structured_output_request_ids
:
num_spec_tokens
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
)
...
...
@@ -101,22 +102,28 @@ def apply_grammar_bitmask(
sorted_bitmask
[
logit_index
+
i
]
=
grammar_bitmask
[
cumulative_index
+
i
]
out_indices
.
append
(
logit_index
+
i
)
cumulative_index
+=
1
+
num_spec_tokens
grammar_bitmask
=
sorted_bitmask
# Copy async to device as tensor.
grammar_bitmask
=
torch
.
from_numpy
(
sorted_bitmask
).
to
(
logits
.
device
,
non_blocking
=
True
)
# If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices
=
len
(
out_indices
)
==
logits
.
shape
[
0
]
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask
=
torch
.
from_numpy
(
grammar_bitmask
).
contiguous
()
index_tensor
=
None
if
not
skip_out_indices
:
# xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
# manner and there should be no cpu sync within xgrammar.
index_tensor
=
torch
.
tensor
(
out_indices
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
True
)
index_tensor
=
index_tensor
.
to
(
logits
.
device
,
non_blocking
=
True
)
xgr
.
apply_token_bitmask_inplace
(
logits
,
grammar_bitmask
.
to
(
logits
.
device
,
non_blocking
=
True
),
indices
=
out_indices
if
not
skip_out_indices
else
None
,
)
xgr
.
apply_token_bitmask_inplace
(
logits
,
grammar_bitmask
,
indices
=
index_tensor
)
class
OutlinesVocabulary
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment