Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
289eb6c5
Unverified
Commit
289eb6c5
authored
Nov 09, 2025
by
Nick Hill
Committed by
GitHub
Nov 09, 2025
Browse files
[Core] Simplify async KV output aggregation (#28327)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
19d91ece
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
153 deletions
+45
-153
tests/v1/executor/test_executor.py
tests/v1/executor/test_executor.py
+9
-1
tests/v1/kv_connector/unit/test_output_aggregator.py
tests/v1/kv_connector/unit/test_output_aggregator.py
+0
-69
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+0
-40
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+36
-43
No files found.
tests/v1/executor/test_executor.py
View file @
289eb6c5
...
@@ -9,6 +9,7 @@ from typing import Any
...
@@ -9,6 +9,7 @@ from typing import Any
import
pytest
import
pytest
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.async_llm
import
AsyncLLM
...
@@ -28,12 +29,19 @@ class CustomMultiprocExecutor(MultiprocExecutor):
...
@@ -28,12 +29,19 @@ class CustomMultiprocExecutor(MultiprocExecutor):
kwargs
:
dict
|
None
=
None
,
kwargs
:
dict
|
None
=
None
,
non_block
:
bool
=
False
,
non_block
:
bool
=
False
,
unique_reply_rank
:
int
|
None
=
None
,
unique_reply_rank
:
int
|
None
=
None
,
kv_output_aggregator
:
KVOutputAggregator
=
None
,
)
->
Any
|
list
[
Any
]
|
Future
[
Any
|
list
[
Any
]]:
)
->
Any
|
list
[
Any
]
|
Future
[
Any
|
list
[
Any
]]:
# Drop marker to show that this was run
# Drop marker to show that this was run
with
open
(
".marker"
,
"w"
):
with
open
(
".marker"
,
"w"
):
...
...
return
super
().
collective_rpc
(
return
super
().
collective_rpc
(
method
,
timeout
,
args
,
kwargs
,
non_block
,
unique_reply_rank
method
,
timeout
,
args
,
kwargs
,
non_block
,
unique_reply_rank
,
kv_output_aggregator
,
)
)
...
...
tests/v1/kv_connector/unit/test_output_aggregator.py
View file @
289eb6c5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
concurrent.futures
import
Future
import
pytest
import
pytest
...
@@ -86,74 +85,6 @@ def test_aggregate_workers_output():
...
@@ -86,74 +85,6 @@ def test_aggregate_workers_output():
assert
aggregated
.
invalid_block_ids
==
{
3
,
4
,
5
}
assert
aggregated
.
invalid_block_ids
==
{
3
,
4
,
5
}
def
test_async_aggregate_workers_output
():
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
2
)
future
:
Future
[
list
[
DummyModelRunnerOutput
]]
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
output1
=
DummyModelRunnerOutput
()
output2
=
DummyModelRunnerOutput
()
future
.
set_result
([
output1
,
output2
])
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
assert
aggregated
is
output1
aggregated
=
aggregated
.
kv_connector_output
assert
aggregated
.
finished_sending
is
None
assert
aggregated
.
finished_recving
is
None
assert
not
aggregated
.
invalid_block_ids
future
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
output1
=
DummyModelRunnerOutput
(
finished_sending
=
{
"req1"
},
finished_recving
=
{
"req2"
}
)
output2
=
DummyModelRunnerOutput
(
invalid_block_ids
=
{
1
})
future
.
set_result
([
output1
,
output2
])
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
assert
aggregated
is
output1
aggregated
=
aggregated
.
kv_connector_output
assert
aggregated
.
finished_sending
is
None
assert
aggregated
.
finished_recving
is
None
assert
aggregated
.
invalid_block_ids
==
{
1
}
future
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
output1
=
DummyModelRunnerOutput
(
invalid_block_ids
=
{
2
})
output2
=
DummyModelRunnerOutput
(
finished_sending
=
{
"req1"
})
future
.
set_result
([
output1
,
output2
])
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
assert
aggregated
is
output1
aggregated
=
aggregated
.
kv_connector_output
assert
aggregated
.
finished_sending
==
{
"req1"
}
assert
aggregated
.
finished_recving
is
None
assert
aggregated
.
invalid_block_ids
==
{
2
}
future
=
Future
()
result_future
=
aggregator
.
async_aggregate
(
future
)
output1
=
DummyModelRunnerOutput
(
invalid_block_ids
=
{
3
,
4
})
output2
=
DummyModelRunnerOutput
(
finished_recving
=
{
"req2"
},
invalid_block_ids
=
{
4
,
5
}
)
future
.
set_result
([
output1
,
output2
])
assert
result_future
.
done
()
aggregated
=
result_future
.
result
()
assert
aggregated
is
output1
aggregated
=
aggregated
.
kv_connector_output
assert
aggregated
.
finished_sending
is
None
assert
aggregated
.
finished_recving
==
{
"req2"
}
assert
aggregated
.
invalid_block_ids
==
{
3
,
4
,
5
}
def
test_aggregate_workers_output_with_expected_finished_count
():
def
test_aggregate_workers_output_with_expected_finished_count
():
# We create the aggregator expecting to collect from 4 workers
# We create the aggregator expecting to collect from 4 workers
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
4
)
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
4
)
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
289eb6c5
...
@@ -4,9 +4,6 @@
...
@@ -4,9 +4,6 @@
KV cache helper for store.
KV cache helper for store.
"""
"""
import
contextlib
from
collections.abc
import
Sequence
from
concurrent.futures
import
CancelledError
,
Future
from
typing
import
TYPE_CHECKING
,
Literal
from
typing
import
TYPE_CHECKING
,
Literal
import
torch
import
torch
...
@@ -220,43 +217,6 @@ class KVOutputAggregator:
...
@@ -220,43 +217,6 @@ class KVOutputAggregator:
return
output
return
output
def
async_aggregate
(
self
,
output_future
:
Future
[
Sequence
[
ModelRunnerOutput
|
None
]],
output_rank
:
int
=
0
,
)
->
Future
[
ModelRunnerOutput
|
None
]:
"""Takes a future that resolves to a list of outputs and returns a future
which resolves to a single aggregated output."""
result_future
:
Future
[
ModelRunnerOutput
|
None
]
=
Future
()
def
callback
(
fut
):
if
result_future
.
done
():
return
try
:
result_future
.
set_result
(
self
.
aggregate
(
fut
.
result
(),
output_rank
))
except
CancelledError
:
result_future
.
cancel
()
except
Exception
as
e
:
result_future
.
set_exception
(
e
)
output_future
.
add_done_callback
(
callback
)
from
vllm.v1.executor.multiproc_executor
import
FutureWrapper
if
isinstance
(
output_future
,
FutureWrapper
):
# Due to the threadless implementation of multiproc FutureWrapper,
# we must block on the delegate future's result() method.
delegate_result
=
result_future
.
result
def
result
(
timeout
=
None
):
with
contextlib
.
suppress
(
Exception
):
output_future
.
result
(
timeout
=
timeout
)
return
delegate_result
()
result_future
.
result
=
result
# type: ignore[method-assign]
return
result_future
def
_make_src_and_dst_indices
(
def
_make_src_and_dst_indices
(
src_block_ids
:
list
[
int
],
src_block_ids
:
list
[
int
],
...
...
vllm/v1/executor/multiproc_executor.py
View file @
289eb6c5
...
@@ -29,6 +29,7 @@ import vllm.envs as envs
...
@@ -29,6 +29,7 @@ import vllm.envs as envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
destroy_distributed_environment
,
destroy_model_parallel
from
vllm.distributed
import
destroy_distributed_environment
,
destroy_model_parallel
from
vllm.distributed.device_communicators.shm_broadcast
import
Handle
,
MessageQueue
from
vllm.distributed.device_communicators.shm_broadcast
import
Handle
,
MessageQueue
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_dp_group
,
get_dp_group
,
get_ep_group
,
get_ep_group
,
...
@@ -57,8 +58,13 @@ logger = init_logger(__name__)
...
@@ -57,8 +58,13 @@ logger = init_logger(__name__)
class
FutureWrapper
(
Future
):
class
FutureWrapper
(
Future
):
def
__init__
(
self
,
futures_queue
:
deque
[
tuple
[
"FutureWrapper"
,
Callable
]]):
def
__init__
(
self
,
futures_queue
:
deque
[
tuple
[
"FutureWrapper"
,
Callable
]],
aggregate
:
Callable
=
lambda
x
:
x
,
):
self
.
futures_queue
=
futures_queue
self
.
futures_queue
=
futures_queue
self
.
aggregate
=
aggregate
super
().
__init__
()
super
().
__init__
()
def
result
(
self
,
timeout
=
None
):
def
result
(
self
,
timeout
=
None
):
...
@@ -72,7 +78,7 @@ class FutureWrapper(Future):
...
@@ -72,7 +78,7 @@ class FutureWrapper(Future):
def
wait_for_response
(
self
,
get_response
:
Callable
):
def
wait_for_response
(
self
,
get_response
:
Callable
):
try
:
try
:
response
=
get_response
()
response
=
self
.
aggregate
(
get_response
()
)
with
suppress
(
InvalidStateError
):
with
suppress
(
InvalidStateError
):
self
.
set_result
(
response
)
self
.
set_result
(
response
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -160,7 +166,6 @@ class MultiprocExecutor(Executor):
...
@@ -160,7 +166,6 @@ class MultiprocExecutor(Executor):
self
.
futures_queue
=
deque
[
tuple
[
FutureWrapper
,
Callable
]]()
self
.
futures_queue
=
deque
[
tuple
[
FutureWrapper
,
Callable
]]()
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
def
start_worker_monitor
(
self
):
def
start_worker_monitor
(
self
):
workers
=
self
.
workers
workers
=
self
.
workers
...
@@ -199,44 +204,27 @@ class MultiprocExecutor(Executor):
...
@@ -199,44 +204,27 @@ class MultiprocExecutor(Executor):
def
execute_model
(
# type: ignore[override]
def
execute_model
(
# type: ignore[override]
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
self
,
scheduler_output
:
SchedulerOutput
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
return
self
.
_execute_with_aggregation
(
return
self
.
collective_rpc
(
"execute_model"
,
scheduler_output
,
non_block
=
non_block
"execute_model"
,
args
=
(
scheduler_output
,),
unique_reply_rank
=
self
.
output_rank
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
,
kv_output_aggregator
=
self
.
kv_output_aggregator
,
)
)
def
sample_tokens
(
# type: ignore[override]
def
sample_tokens
(
# type: ignore[override]
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
bool
=
False
self
,
grammar_output
:
GrammarOutput
|
None
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
)
->
ModelRunnerOutput
|
Future
[
ModelRunnerOutput
]:
return
self
.
_execute_with_aggregation
(
# type: ignore[return-value]
return
self
.
collective_rpc
(
"sample_tokens"
,
grammar_output
,
non_block
=
non_block
"sample_tokens"
,
)
args
=
(
grammar_output
,),
unique_reply_rank
=
self
.
output_rank
,
def
_execute_with_aggregation
(
self
,
method
:
str
,
*
args
,
non_block
:
bool
=
False
)
->
ModelRunnerOutput
|
None
|
Future
[
ModelRunnerOutput
|
None
]:
if
not
self
.
has_connector
:
# get output only from a single worker (output_rank)
return
self
.
collective_rpc
(
method
,
args
=
args
,
unique_reply_rank
=
self
.
output_rank
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
,
)
# get output from all workers
outputs
=
self
.
collective_rpc
(
method
,
args
=
args
,
non_block
=
non_block
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
,
kv_output_aggregator
=
self
.
kv_output_aggregator
,
)
)
# aggregate all workers output to a single output
assert
self
.
kv_output_aggregator
is
not
None
if
non_block
:
return
self
.
kv_output_aggregator
.
async_aggregate
(
outputs
,
self
.
output_rank
)
return
self
.
kv_output_aggregator
.
aggregate
(
outputs
,
self
.
output_rank
)
def
execute_dummy_batch
(
self
)
->
None
:
def
execute_dummy_batch
(
self
)
->
None
:
self
.
collective_rpc
(
"execute_dummy_batch"
,
unique_reply_rank
=
self
.
output_rank
)
self
.
collective_rpc
(
"execute_dummy_batch"
,
unique_reply_rank
=
self
.
output_rank
)
...
@@ -254,8 +242,10 @@ class MultiprocExecutor(Executor):
...
@@ -254,8 +242,10 @@ class MultiprocExecutor(Executor):
kwargs
:
dict
|
None
=
None
,
kwargs
:
dict
|
None
=
None
,
non_block
:
bool
=
False
,
non_block
:
bool
=
False
,
unique_reply_rank
:
int
|
None
=
None
,
unique_reply_rank
:
int
|
None
=
None
,
kv_output_aggregator
:
KVOutputAggregator
=
None
,
)
->
Any
|
list
[
Any
]
|
Future
[
Any
|
list
[
Any
]]:
)
->
Any
|
list
[
Any
]
|
Future
[
Any
|
list
[
Any
]]:
"""Returns single result if unique_reply_rank is provided, otherwise list."""
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
is provided, otherwise list."""
if
self
.
is_failed
:
if
self
.
is_failed
:
raise
RuntimeError
(
"Executor failed."
)
raise
RuntimeError
(
"Executor failed."
)
...
@@ -263,20 +253,23 @@ class MultiprocExecutor(Executor):
...
@@ -263,20 +253,23 @@ class MultiprocExecutor(Executor):
deadline
=
None
if
timeout
is
None
else
time
.
monotonic
()
+
timeout
deadline
=
None
if
timeout
is
None
else
time
.
monotonic
()
+
timeout
kwargs
=
kwargs
or
{}
kwargs
=
kwargs
or
{}
# NOTE: If the args are heterogeneous, then we pack them into a list,
if
kv_output_aggregator
is
not
None
:
# and unpack them in the method of every worker, because every worker
output_rank
=
None
# knows their own rank.
aggregate
:
Callable
[[
Any
],
Any
]
=
partial
(
kv_output_aggregator
.
aggregate
,
output_rank
=
unique_reply_rank
or
0
)
else
:
output_rank
=
unique_reply_rank
aggregate
=
lambda
x
:
x
if
isinstance
(
method
,
str
):
if
isinstance
(
method
,
str
):
send_method
=
method
send_method
=
method
else
:
else
:
send_method
=
cloudpickle
.
dumps
(
method
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
send_method
=
cloudpickle
.
dumps
(
method
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
self
.
rpc_broadcast_mq
.
enqueue
((
send_method
,
args
,
kwargs
,
unique_reply
_rank
))
self
.
rpc_broadcast_mq
.
enqueue
((
send_method
,
args
,
kwargs
,
output
_rank
))
workers
=
(
workers
=
(
(
self
.
workers
[
unique_reply_rank
],)
(
self
.
workers
[
output_rank
],)
if
output_rank
is
not
None
else
self
.
workers
if
unique_reply_rank
is
not
None
else
self
.
workers
)
)
shutdown_event
=
self
.
shutdown_event
shutdown_event
=
self
.
shutdown_event
...
@@ -299,10 +292,10 @@ class MultiprocExecutor(Executor):
...
@@ -299,10 +292,10 @@ class MultiprocExecutor(Executor):
" stack trace above for the root cause"
" stack trace above for the root cause"
)
)
responses
.
append
(
result
)
responses
.
append
(
result
)
return
responses
[
0
]
if
unique_reply
_rank
is
not
None
else
responses
return
responses
[
0
]
if
output
_rank
is
not
None
else
responses
if
non_block
:
if
non_block
:
future
=
FutureWrapper
(
self
.
futures_queue
)
future
=
FutureWrapper
(
self
.
futures_queue
,
aggregate
=
aggregate
)
self
.
futures_queue
.
appendleft
((
future
,
get_response
))
self
.
futures_queue
.
appendleft
((
future
,
get_response
))
return
future
return
future
...
@@ -311,7 +304,7 @@ class MultiprocExecutor(Executor):
...
@@ -311,7 +304,7 @@ class MultiprocExecutor(Executor):
future
,
get_fut_response
=
self
.
futures_queue
.
pop
()
future
,
get_fut_response
=
self
.
futures_queue
.
pop
()
future
.
wait_for_response
(
get_fut_response
)
future
.
wait_for_response
(
get_fut_response
)
return
get_response
()
return
aggregate
(
get_response
()
)
@
staticmethod
@
staticmethod
def
_ensure_worker_termination
(
worker_procs
:
list
[
BaseProcess
]):
def
_ensure_worker_termination
(
worker_procs
:
list
[
BaseProcess
]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment