Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4dfdb821
Unverified
Commit
4dfdb821
authored
Oct 22, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Oct 22, 2025
Browse files
[P/D] Dynamic `kv_output_aggregator` collect size (#26734)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
58fab50d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
90 additions
and
19 deletions
+90
-19
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+2
-2
tests/v1/kv_connector/unit/test_output_aggregator.py
tests/v1/kv_connector/unit/test_output_aggregator.py
+41
-2
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+30
-6
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+2
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-3
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+7
-4
vllm/v1/outputs.py
vllm/v1/outputs.py
+7
-1
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
4dfdb821
...
@@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
...
@@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model
# done in MultiprocExecutor.execute_model
aggregator
=
KVOutputAggregator
(
world_size
=
3
)
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
3
)
# Create stats for multiple workers with different transfer patterns
# Create stats for multiple workers with different transfer patterns
worker1_stats
=
NixlKVConnectorStats
()
worker1_stats
=
NixlKVConnectorStats
()
...
@@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
...
@@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
KVOutputAggregator (used by MultiprocExecutor).
KVOutputAggregator (used by MultiprocExecutor).
"""
"""
aggregator
=
KVOutputAggregator
(
world_size
=
3
)
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
3
)
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
...
tests/v1/kv_connector/unit/test_output_aggre
a
gator.py
→
tests/v1/kv_connector/unit/test_output_aggregator.py
View file @
4dfdb821
...
@@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
...
@@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
finished_sending
:
set
[
str
]
|
None
=
None
,
finished_sending
:
set
[
str
]
|
None
=
None
,
finished_recving
:
set
[
str
]
|
None
=
None
,
finished_recving
:
set
[
str
]
|
None
=
None
,
invalid_block_ids
:
set
[
int
]
|
None
=
None
,
invalid_block_ids
:
set
[
int
]
|
None
=
None
,
expected_finished_count
:
int
=
0
,
):
):
self
.
kv_connector_output
=
KVConnectorOutput
(
self
.
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
finished_sending
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
finished_recving
=
finished_recving
,
invalid_block_ids
=
invalid_block_ids
or
set
(),
invalid_block_ids
=
invalid_block_ids
or
set
(),
expected_finished_count
=
expected_finished_count
,
)
)
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
...
@@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def
test_aggregate_workers_output
():
def
test_aggregate_workers_output
():
aggregator
=
KVOutputAggregator
(
world_size
=
2
)
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
2
)
output1
=
DummyModelRunnerOutput
()
output1
=
DummyModelRunnerOutput
()
output2
=
DummyModelRunnerOutput
()
output2
=
DummyModelRunnerOutput
()
...
@@ -85,7 +87,7 @@ def test_aggregate_workers_output():
...
@@ -85,7 +87,7 @@ def test_aggregate_workers_output():
def
test_async_aggregate_workers_output
():
def
test_async_aggregate_workers_output
():
aggregator
=
KVOutputAggregator
(
world_size
=
2
)
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
2
)
future1
:
Future
[
DummyModelRunnerOutput
]
=
Future
()
future1
:
Future
[
DummyModelRunnerOutput
]
=
Future
()
future2
:
Future
[
DummyModelRunnerOutput
]
=
Future
()
future2
:
Future
[
DummyModelRunnerOutput
]
=
Future
()
...
@@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
...
@@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
assert
aggregated
.
finished_sending
is
None
assert
aggregated
.
finished_sending
is
None
assert
aggregated
.
finished_recving
==
{
"req2"
}
assert
aggregated
.
finished_recving
==
{
"req2"
}
assert
aggregated
.
invalid_block_ids
==
{
3
,
4
,
5
}
assert
aggregated
.
invalid_block_ids
==
{
3
,
4
,
5
}
def
test_aggregate_workers_output_with_expected_finished_count
():
# We create the aggregator expecting to collect from 4 workers
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
4
)
assert
aggregator
.
_expected_finished_count
==
4
# Some request with default expected finished requests
output1
=
DummyModelRunnerOutput
(
finished_sending
=
{
"req1"
})
aggregated
=
aggregator
.
aggregate
([
output1
])
# still expecting to collect from 4 workers
assert
aggregator
.
_send_remaining_count
[
"req1"
]
==
3
assert
not
aggregated
.
kv_connector_output
.
finished_sending
assert
not
aggregated
.
kv_connector_output
.
finished_recving
# Workers discover and find that in this setup they only need to
# collect from 2
output1
=
DummyModelRunnerOutput
(
finished_sending
=
{
"req1"
},
expected_finished_count
=
2
)
output2
=
DummyModelRunnerOutput
(
finished_recving
=
{
"req2"
},
expected_finished_count
=
2
)
output3
=
DummyModelRunnerOutput
(
finished_recving
=
{
"req2"
})
# Req2 only needs 2 acks
aggregated
=
aggregator
.
aggregate
([
output1
,
output2
,
output3
])
assert
aggregated
.
kv_connector_output
.
expected_finished_count
==
2
assert
not
aggregated
.
kv_connector_output
.
finished_sending
# Req2 is finished
assert
"req2"
not
in
aggregator
.
_recv_remaining_count
assert
aggregated
.
kv_connector_output
.
finished_recving
==
{
"req2"
}
# Req1 is still waiting for 2 more acks (expected_finished_count has no effect)
# NOTE: This is to showcase dynamic update. Workers are responsible for
# ensuring "req1" termination in this case
assert
aggregator
.
_send_remaining_count
[
"req1"
]
==
2
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
4dfdb821
...
@@ -4,10 +4,9 @@
...
@@ -4,10 +4,9 @@
KV cache helper for store.
KV cache helper for store.
"""
"""
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
concurrent.futures
import
CancelledError
,
Future
from
concurrent.futures
import
CancelledError
,
Future
from
typing
import
Literal
,
cast
from
typing
import
TYPE_CHECKING
,
Literal
,
cast
import
torch
import
torch
...
@@ -18,6 +17,9 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
...
@@ -18,6 +17,9 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -124,11 +126,16 @@ class KVOutputAggregator:
...
@@ -124,11 +126,16 @@ class KVOutputAggregator:
"""Utility class to aggregate the output of all workers into a single
"""Utility class to aggregate the output of all workers into a single
output corresponding to Rank 0 for scheduler."""
output corresponding to Rank 0 for scheduler."""
def
__init__
(
self
,
world_size
:
int
):
def
__init__
(
self
,
expected_finished_count
:
int
):
# Complete transfer tracker. Used to track finished requests
# Complete transfer tracker. Used to track finished requests
# [req_id -> n_remaining_workers]
# [req_id -> n_remaining_workers]
self
.
_recv_remaining_count
=
defaultdict
[
str
,
int
](
lambda
:
world_size
)
self
.
_recv_remaining_count
=
dict
[
str
,
int
]()
self
.
_send_remaining_count
=
defaultdict
[
str
,
int
](
lambda
:
world_size
)
self
.
_send_remaining_count
=
dict
[
str
,
int
]()
self
.
_expected_finished_count
=
expected_finished_count
@
classmethod
def
from_connector
(
cls
,
connector
:
"KVConnectorBase"
,
world_size
:
int
):
return
cls
(
connector
.
get_finished_count
()
or
world_size
)
def
aggregate
(
def
aggregate
(
self
,
outputs
:
list
[
ModelRunnerOutput
],
output_rank
:
int
=
0
self
,
outputs
:
list
[
ModelRunnerOutput
],
output_rank
:
int
=
0
...
@@ -141,7 +148,10 @@ class KVOutputAggregator:
...
@@ -141,7 +148,10 @@ class KVOutputAggregator:
finished_set
:
set
[
str
],
finished_set
:
set
[
str
],
)
->
None
:
)
->
None
:
for
req_id
in
req_ids
or
():
for
req_id
in
req_ids
or
():
remaining_count_dict
[
req_id
]
-=
1
remaining_count
=
remaining_count_dict
.
get
(
req_id
,
self
.
_expected_finished_count
)
remaining_count_dict
[
req_id
]
=
remaining_count
-
1
if
remaining_count_dict
[
req_id
]
==
0
:
if
remaining_count_dict
[
req_id
]
==
0
:
finished_set
.
add
(
req_id
)
finished_set
.
add
(
req_id
)
del
remaining_count_dict
[
req_id
]
del
remaining_count_dict
[
req_id
]
...
@@ -154,6 +164,19 @@ class KVOutputAggregator:
...
@@ -154,6 +164,19 @@ class KVOutputAggregator:
kv_output
=
model_runner_output
.
kv_connector_output
kv_output
=
model_runner_output
.
kv_connector_output
if
not
kv_output
:
if
not
kv_output
:
continue
continue
# Allow the worker to dynamically update the expected number of
# finished sending/recving for new requests.
if
(
kv_output
.
expected_finished_count
>
0
and
kv_output
.
expected_finished_count
!=
self
.
_expected_finished_count
):
logger
.
debug
(
"Expected finished requests updated from %d to %d"
,
self
.
_expected_finished_count
,
kv_output
.
expected_finished_count
,
)
self
.
_expected_finished_count
=
kv_output
.
expected_finished_count
update_finished_set
(
update_finished_set
(
kv_output
.
finished_sending
,
self
.
_send_remaining_count
,
finished_sending
kv_output
.
finished_sending
,
self
.
_send_remaining_count
,
finished_sending
)
)
...
@@ -186,6 +209,7 @@ class KVOutputAggregator:
...
@@ -186,6 +209,7 @@ class KVOutputAggregator:
finished_recving
=
finished_recving
or
None
,
finished_recving
=
finished_recving
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
invalid_block_ids
=
invalid_block_ids
,
invalid_block_ids
=
invalid_block_ids
,
expected_finished_count
=
self
.
_expected_finished_count
,
)
)
return
output
return
output
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
4dfdb821
...
@@ -413,7 +413,8 @@ class KVConnectorBase_V1(ABC):
...
@@ -413,7 +413,8 @@ class KVConnectorBase_V1(ABC):
def
get_finished_count
(
self
)
->
int
|
None
:
def
get_finished_count
(
self
)
->
int
|
None
:
"""
"""
Get the count of requests expected to complete send/receive operations
Get the count of requests expected to complete send/receive operations
via this connector.
via this connector. This method is used to initialize the
KVOutputAggregator, overwriting the default world_size.
Returns:
Returns:
int: expected sending or receiving completion count.
int: expected sending or receiving completion count.
...
...
vllm/v1/engine/core.py
View file @
4dfdb821
...
@@ -160,9 +160,7 @@ class EngineCore:
...
@@ -160,9 +160,7 @@ class EngineCore:
)
)
self
.
use_spec_decode
=
vllm_config
.
speculative_config
is
not
None
self
.
use_spec_decode
=
vllm_config
.
speculative_config
is
not
None
if
self
.
scheduler
.
connector
is
not
None
:
# type: ignore
if
self
.
scheduler
.
connector
is
not
None
:
# type: ignore
self
.
model_executor
.
init_kv_output_aggregator
(
self
.
model_executor
.
init_kv_output_aggregator
(
self
.
scheduler
.
connector
)
# type: ignore
self
.
scheduler
.
connector
.
get_finished_count
()
# type: ignore
)
self
.
mm_registry
=
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_registry
=
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_receiver_cache
=
engine_receiver_cache_from_config
(
self
.
mm_receiver_cache
=
engine_receiver_cache_from_config
(
...
...
vllm/v1/executor/abstract.py
View file @
4dfdb821
...
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
...
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Literal
,
TypeVar
,
overload
from
typing
import
TYPE_CHECKING
,
Literal
,
TypeVar
,
overload
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
...
@@ -19,6 +19,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
...
@@ -19,6 +19,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
WorkerBase
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
)
_R
=
TypeVar
(
"_R"
)
...
@@ -233,10 +236,10 @@ class Executor(ABC):
...
@@ -233,10 +236,10 @@ class Executor(ABC):
"""Shutdown the executor."""
"""Shutdown the executor."""
self
.
collective_rpc
(
"shutdown"
)
self
.
collective_rpc
(
"shutdown"
)
def
init_kv_output_aggregator
(
self
,
finished_count
:
int
|
None
)
->
None
:
def
init_kv_output_aggregator
(
self
,
connector
:
"KVConnectorBase"
)
->
None
:
"""Init KVOutputAggregator"""
"""Init KVOutputAggregator"""
self
.
kv_output_aggregator
=
KVOutputAggregator
(
self
.
kv_output_aggregator
=
KVOutputAggregator
.
from_connector
(
finished_count
or
self
.
parallel_config
.
world_size
connect
or
,
self
.
parallel_config
.
world_size
)
)
@
cached_property
# Avoid unnecessary RPC calls
@
cached_property
# Avoid unnecessary RPC calls
...
...
vllm/v1/outputs.py
View file @
4dfdb821
...
@@ -86,8 +86,14 @@ class KVConnectorOutput:
...
@@ -86,8 +86,14 @@ class KVConnectorOutput:
finished_recving
:
set
[
str
]
|
None
=
None
finished_recving
:
set
[
str
]
|
None
=
None
kv_connector_stats
:
KVConnectorStats
|
None
=
None
kv_connector_stats
:
KVConnectorStats
|
None
=
None
# IDs of externally computed KV blocks that failed to load.
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them
.
# Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids
:
set
[
int
]
=
field
(
default_factory
=
set
)
invalid_block_ids
:
set
[
int
]
=
field
(
default_factory
=
set
)
# Configuration describing how many finished sending/receiving
# notifications should be expected for each request. This allows
# handshake-based connectors like Nixl to update the KVOutputAggregator.
# It captures a static setup info and should almost always remain constant
# for a given connector after discovery. Default value entails no change.
expected_finished_count
:
int
=
0
def
is_empty
(
self
):
def
is_empty
(
self
):
return
(
return
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment