Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7765e5ba
Unverified
Commit
7765e5ba
authored
Nov 17, 2025
by
Nick Hill
Committed by
GitHub
Nov 17, 2025
Browse files
[BugFix] Fix PP performance and PP kv connector output regression (#28768)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
d8874c61
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
105 additions
and
104 deletions
+105
-104
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+66
-84
vllm/v1/executor/ray_executor.py
vllm/v1/executor/ray_executor.py
+20
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+18
-5
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-14
No files found.
vllm/v1/engine/core.py
View file @
7765e5ba
...
@@ -63,7 +63,6 @@ from vllm.v1.outputs import ModelRunnerOutput
...
@@ -63,7 +63,6 @@ from vllm.v1.outputs import ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.utils
import
record_function_or_nullcontext
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -181,11 +180,13 @@ class EngineCore:
...
@@ -181,11 +180,13 @@ class EngineCore:
logger
.
info
(
"Batch queue is enabled with size %d"
,
self
.
batch_queue_size
)
logger
.
info
(
"Batch queue is enabled with size %d"
,
self
.
batch_queue_size
)
self
.
batch_queue
=
deque
(
maxlen
=
self
.
batch_queue_size
)
self
.
batch_queue
=
deque
(
maxlen
=
self
.
batch_queue_size
)
self
.
ec_producer
=
(
vllm_config
.
ec_transfer_config
is
not
None
and
vllm_config
.
ec_transfer_config
.
is_ec_producer
)
self
.
request_block_hasher
:
Callable
[[
Request
],
list
[
BlockHash
]]
|
None
=
None
self
.
request_block_hasher
:
Callable
[[
Request
],
list
[
BlockHash
]]
|
None
=
None
if
(
if
vllm_config
.
cache_config
.
enable_prefix_caching
or
kv_connector
is
not
None
:
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
or
kv_connector
is
not
None
):
caching_hash_fn
=
get_hash_fn_by_name
(
caching_hash_fn
=
get_hash_fn_by_name
(
vllm_config
.
cache_config
.
prefix_caching_hash_algo
vllm_config
.
cache_config
.
prefix_caching_hash_algo
)
)
...
@@ -246,7 +247,7 @@ class EngineCore:
...
@@ -246,7 +247,7 @@ class EngineCore:
elapsed
=
time
.
time
()
-
start
elapsed
=
time
.
time
()
-
start
logger
.
info_once
(
logger
.
info_once
(
(
"init engine (profile, create kv cache, warmup model) took %.2f seconds"
)
,
"init engine (profile, create kv cache, warmup model) took %.2f seconds"
,
elapsed
,
elapsed
,
scope
=
"local"
,
scope
=
"local"
,
)
)
...
@@ -312,6 +313,16 @@ class EngineCore:
...
@@ -312,6 +313,16 @@ class EngineCore:
)
)
raise
err
raise
err
def
_log_err_callback
(
self
,
scheduler_output
:
SchedulerOutput
):
"""Log error details of a future that's not expected to return a result."""
def
callback
(
f
,
sched_output
=
scheduler_output
):
with
self
.
log_error_detail
(
sched_output
):
result
=
f
.
result
()
assert
result
is
None
return
callback
def
step
(
self
)
->
tuple
[
dict
[
int
,
EngineCoreOutputs
],
bool
]:
def
step
(
self
)
->
tuple
[
dict
[
int
,
EngineCoreOutputs
],
bool
]:
"""Schedule, execute, and make output.
"""Schedule, execute, and make output.
...
@@ -323,10 +334,7 @@ class EngineCore:
...
@@ -323,10 +334,7 @@ class EngineCore:
# or finished and not yet removed from the batch.
# or finished and not yet removed from the batch.
if
not
self
.
scheduler
.
has_requests
():
if
not
self
.
scheduler
.
has_requests
():
return
{},
False
return
{},
False
with
record_function_or_nullcontext
(
"core step: schedule"
):
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
=
self
.
scheduler
.
schedule
()
with
record_function_or_nullcontext
(
"core step: execute_model"
):
future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
with
self
.
log_error_detail
(
scheduler_output
):
with
self
.
log_error_detail
(
scheduler_output
):
...
@@ -334,7 +342,6 @@ class EngineCore:
...
@@ -334,7 +342,6 @@ class EngineCore:
if
model_output
is
None
:
if
model_output
is
None
:
model_output
=
self
.
model_executor
.
sample_tokens
(
grammar_output
)
model_output
=
self
.
model_executor
.
sample_tokens
(
grammar_output
)
with
record_function_or_nullcontext
(
"core step: update_from_output"
):
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
scheduler_output
,
model_output
)
)
...
@@ -378,52 +385,34 @@ class EngineCore:
...
@@ -378,52 +385,34 @@ class EngineCore:
model_executed
=
False
model_executed
=
False
deferred_scheduler_output
=
None
deferred_scheduler_output
=
None
if
self
.
scheduler
.
has_requests
():
if
self
.
scheduler
.
has_requests
():
with
record_function_or_nullcontext
(
"core step_with_batch_queue: schedule"
):
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
=
self
.
scheduler
.
schedule
()
with
record_function_or_nullcontext
(
"core step_with_batch_queue: execute_model"
):
exec_future
=
self
.
model_executor
.
execute_model
(
exec_future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
scheduler_output
,
non_block
=
True
)
)
if
not
self
.
ec_producer
:
model_executed
=
scheduler_output
.
total_num_scheduled_tokens
>
0
model_executed
=
scheduler_output
.
total_num_scheduled_tokens
>
0
if
scheduler_output
.
pending_structured_output_tokens
:
if
not
model_executed
:
with
record_function_or_nullcontext
(
# No sampling required (no requests scheduled).
"core step_with_batch_queue: pending_structured_output_tokens"
future
=
cast
(
Future
[
ModelRunnerOutput
],
exec_future
)
):
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output
=
scheduler_output
# Block-wait for execute to return
# (continues running async on the GPU).
with
self
.
log_error_detail
(
scheduler_output
):
exec_result
=
exec_future
.
result
()
assert
exec_result
is
None
else
:
else
:
with
record_function_or_nullcontext
(
exec_future
.
add_done_callback
(
self
.
_log_err_callback
(
scheduler_output
))
"core step_with_batch_queue: get_grammar_bitmask"
)
:
if
not
scheduler_output
.
pending_structured_output_tokens
:
# We aren't waiting for any tokens, get any grammar
# We aren't waiting for any tokens, get any grammar
output
#
output
immediately.
#
and sample
immediately.
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
scheduler_output
)
)
# Block-wait for execute to return (continues running async on the GPU).
with
self
.
log_error_detail
(
scheduler_output
):
exec_result
=
exec_future
.
result
()
if
exec_result
is
None
:
with
record_function_or_nullcontext
(
"core step_with_batch_queue: sample_tokens"
):
# Call sample tokens.
future
=
self
.
model_executor
.
sample_tokens
(
future
=
self
.
model_executor
.
sample_tokens
(
grammar_output
,
non_block
=
True
grammar_output
,
non_block
=
True
)
)
else
:
else
:
# No sampling required (e.g. all requests finished).
# We need to defer sampling until we have processed the model output
future
=
cast
(
Future
[
ModelRunnerOutput
],
exec_future
)
# from the prior step.
deferred_scheduler_output
=
scheduler_output
if
not
deferred_scheduler_output
:
# Add this step's future to the queue.
# Add this step's future to the queue.
batch_queue
.
appendleft
((
future
,
scheduler_output
))
batch_queue
.
appendleft
((
future
,
scheduler_output
))
if
(
if
(
...
@@ -440,14 +429,12 @@ class EngineCore:
...
@@ -440,14 +429,12 @@ class EngineCore:
# only be called when the scheduler contains requests or the queue
# only be called when the scheduler contains requests or the queue
# is non-empty.
# is non-empty.
return
None
,
False
return
None
,
False
with
record_function_or_nullcontext
(
"core step_with_batch_queue: model_output"
):
# Block until the next result is available.
# Block until the next result is available.
future
,
scheduler_output
=
batch_queue
.
pop
()
future
,
scheduler_output
=
batch_queue
.
pop
()
with
self
.
log_error_detail
(
scheduler_output
):
with
self
.
log_error_detail
(
scheduler_output
):
model_output
=
future
.
result
()
model_output
=
future
.
result
()
with
record_function_or_nullcontext
(
"core step_with_batch_queue: update_from_output"
):
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
scheduler_output
,
model_output
)
)
...
@@ -456,17 +443,12 @@ class EngineCore:
...
@@ -456,17 +443,12 @@ class EngineCore:
# in a field and do it immediately once step_with_batch_queue is
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if
deferred_scheduler_output
:
if
deferred_scheduler_output
:
with
record_function_or_nullcontext
(
"core step_with_batch_queue: deferred_scheduler_output"
):
# We now have the tokens needed to compute the bitmask for the
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
# deferred request. Get the bitmask and call sample tokens.
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
deferred_scheduler_output
deferred_scheduler_output
)
)
future
=
self
.
model_executor
.
sample_tokens
(
future
=
self
.
model_executor
.
sample_tokens
(
grammar_output
,
non_block
=
True
)
grammar_output
,
non_block
=
True
)
batch_queue
.
appendleft
((
future
,
deferred_scheduler_output
))
batch_queue
.
appendleft
((
future
,
deferred_scheduler_output
))
return
engine_core_outputs
,
model_executed
return
engine_core_outputs
,
model_executed
...
...
vllm/v1/executor/ray_executor.py
View file @
7765e5ba
...
@@ -99,6 +99,11 @@ class RayDistributedExecutor(Executor):
...
@@ -99,6 +99,11 @@ class RayDistributedExecutor(Executor):
# KV connector setup
# KV connector setup
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
self
.
ec_producer
=
(
self
.
vllm_config
.
ec_transfer_config
is
not
None
and
self
.
vllm_config
.
ec_transfer_config
.
is_ec_producer
)
self
.
scheduler_output
:
SchedulerOutput
|
None
=
None
self
.
scheduler_output
:
SchedulerOutput
|
None
=
None
@
property
@
property
...
@@ -395,6 +400,12 @@ class RayDistributedExecutor(Executor):
...
@@ -395,6 +400,12 @@ class RayDistributedExecutor(Executor):
"State error: sample_tokens() must be called "
"State error: sample_tokens() must be called "
"after execute_model() returns None."
"after execute_model() returns None."
)
)
if
self
.
ec_producer
or
not
scheduler_output
.
total_num_scheduled_tokens
:
# Model will not execute, call model runner immediately.
return
self
.
_execute_dag
(
scheduler_output
,
None
,
non_block
)
# Model will execute, defer to sample_tokens() call.
self
.
scheduler_output
=
scheduler_output
self
.
scheduler_output
=
scheduler_output
return
COMPLETED_NONE_FUTURE
if
non_block
else
None
return
COMPLETED_NONE_FUTURE
if
non_block
else
None
...
@@ -417,10 +428,18 @@ class RayDistributedExecutor(Executor):
...
@@ -417,10 +428,18 @@ class RayDistributedExecutor(Executor):
"""
"""
scheduler_output
=
self
.
scheduler_output
scheduler_output
=
self
.
scheduler_output
if
scheduler_output
is
None
:
if
scheduler_output
is
None
:
return
None
# noqa
return
COMPLETED_NONE_FUTURE
if
non_block
else
None
# noqa
self
.
scheduler_output
=
None
self
.
scheduler_output
=
None
return
self
.
_execute_dag
(
scheduler_output
,
grammar_output
,
non_block
)
def
_execute_dag
(
self
,
scheduler_output
:
SchedulerOutput
,
grammar_output
:
"GrammarOutput | None"
,
non_block
:
bool
=
False
,
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
# Build the compiled DAG for the first time.
# Build the compiled DAG for the first time.
if
self
.
forward_dag
is
None
:
# type: ignore
if
self
.
forward_dag
is
None
:
# type: ignore
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
7765e5ba
...
@@ -7,7 +7,7 @@ import time
...
@@ -7,7 +7,7 @@ import time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Iterator
from
collections.abc
import
Iterator
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
copy
import
deepcopy
from
copy
import
copy
,
deepcopy
from
functools
import
reduce
from
functools
import
reduce
from
itertools
import
product
from
itertools
import
product
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
...
@@ -250,7 +250,6 @@ class ExecuteModelState(NamedTuple):
...
@@ -250,7 +250,6 @@ class ExecuteModelState(NamedTuple):
hidden_states
:
torch
.
Tensor
hidden_states
:
torch
.
Tensor
sample_hidden_states
:
torch
.
Tensor
sample_hidden_states
:
torch
.
Tensor
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
kv_connector_output
:
KVConnectorOutput
|
None
ec_connector_output
:
ECConnectorOutput
|
None
ec_connector_output
:
ECConnectorOutput
|
None
...
@@ -573,6 +572,7 @@ class GPUModelRunner(
...
@@ -573,6 +572,7 @@ class GPUModelRunner(
# Ephemeral state transferred between execute_model() and sample_tokens().
# Ephemeral state transferred between execute_model() and sample_tokens().
self
.
execute_model_state
:
ExecuteModelState
|
None
=
None
self
.
execute_model_state
:
ExecuteModelState
|
None
=
None
self
.
kv_connector_output
:
KVConnectorOutput
|
None
=
None
def
reset_mm_cache
(
self
)
->
None
:
def
reset_mm_cache
(
self
)
->
None
:
if
self
.
mm_budget
:
if
self
.
mm_budget
:
...
@@ -2803,6 +2803,7 @@ class GPUModelRunner(
...
@@ -2803,6 +2803,7 @@ class GPUModelRunner(
# Return the intermediate tensors.
# Return the intermediate tensors.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
hidden_states
.
kv_connector_output
=
kv_connector_output
self
.
kv_connector_output
=
kv_connector_output
return
hidden_states
return
hidden_states
if
self
.
is_pooling_model
:
if
self
.
is_pooling_model
:
...
@@ -2853,19 +2854,32 @@ class GPUModelRunner(
...
@@ -2853,19 +2854,32 @@ class GPUModelRunner(
hidden_states
,
hidden_states
,
sample_hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
aux_hidden_states
,
kv_connector_output
,
ec_connector_output
,
ec_connector_output
,
)
)
self
.
kv_connector_output
=
kv_connector_output
return
None
return
None
@
torch
.
inference_mode
@
torch
.
inference_mode
def
sample_tokens
(
def
sample_tokens
(
self
,
grammar_output
:
"GrammarOutput | None"
self
,
grammar_output
:
"GrammarOutput | None"
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
IntermediateTensors
:
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
IntermediateTensors
:
kv_connector_output
=
self
.
kv_connector_output
self
.
kv_connector_output
=
None
if
self
.
execute_model_state
is
None
:
if
self
.
execute_model_state
is
None
:
# Nothing to do (PP non-final rank case), output isn't used.
# Nothing to do (PP non-final rank case), output isn't used.
if
not
kv_connector_output
:
return
None
# noqa
return
None
# noqa
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if
kv_connector_output
.
is_empty
():
return
EMPTY_MODEL_RUNNER_OUTPUT
output
=
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
kv_connector_output
=
kv_connector_output
return
output
# Unpack ephemeral state.
# Unpack ephemeral state.
(
(
scheduler_output
,
scheduler_output
,
...
@@ -2875,7 +2889,6 @@ class GPUModelRunner(
...
@@ -2875,7 +2889,6 @@ class GPUModelRunner(
hidden_states
,
hidden_states
,
sample_hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
aux_hidden_states
,
kv_connector_output
,
ec_connector_output
,
ec_connector_output
,
)
=
self
.
execute_model_state
)
=
self
.
execute_model_state
# Clear ephemeral state.
# Clear ephemeral state.
...
...
vllm/v1/worker/gpu_worker.py
View file @
7765e5ba
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
"""A GPU worker class."""
import
copy
import
gc
import
gc
import
os
import
os
from
contextlib
import
AbstractContextManager
,
nullcontext
from
contextlib
import
AbstractContextManager
,
nullcontext
...
@@ -45,7 +44,6 @@ from vllm.v1.core.sched.output import GrammarOutput
...
@@ -45,7 +44,6 @@ from vllm.v1.core.sched.output import GrammarOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
(
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
AsyncModelRunnerOutput
,
AsyncModelRunnerOutput
,
DraftTokenIds
,
DraftTokenIds
,
ModelRunnerOutput
,
ModelRunnerOutput
,
...
@@ -581,19 +579,8 @@ class Worker(WorkerBase):
...
@@ -581,19 +579,8 @@ class Worker(WorkerBase):
all_gather_tensors
=
all_gather_tensors
,
all_gather_tensors
=
all_gather_tensors
,
)
)
kv_connector_output
=
output
.
kv_connector_output
if
not
kv_connector_output
:
return
None
return
None
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if
kv_connector_output
.
is_empty
():
return
EMPTY_MODEL_RUNNER_OUTPUT
output
=
copy
.
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
kv_connector_output
=
kv_connector_output
return
output
def
take_draft_token_ids
(
self
)
->
DraftTokenIds
|
None
:
def
take_draft_token_ids
(
self
)
->
DraftTokenIds
|
None
:
return
self
.
model_runner
.
take_draft_token_ids
()
return
self
.
model_runner
.
take_draft_token_ids
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment