Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aca59674
Unverified
Commit
aca59674
authored
Feb 06, 2026
by
Seiji Eicher
Committed by
GitHub
Feb 06, 2026
Browse files
[KV Connector] Add missing method overrides to MultiConnector (#33292)
Signed-off-by:
Seiji Eicher
<
seiji@anyscale.com
>
parent
67a746e8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
177 additions
and
11 deletions
+177
-11
tests/v1/kv_connector/unit/test_multi_connector.py
tests/v1/kv_connector/unit/test_multi_connector.py
+134
-6
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+43
-5
No files found.
tests/v1/kv_connector/unit/test_multi_connector.py
View file @
aca59674
...
...
@@ -190,27 +190,35 @@ def test_multi_example_connector_consistency():
)
events
=
get_connector_events
()
# get_num_new_matched_tokens and update_state_after_alloc will be called
# on each connector in turn.
assert
events
[
"storage1-SCHEDULER"
][:
3
]
==
[
# First event is set_xfer_handshake_metadata from initialization, then
# get_num_new_matched_tokens and update_state_after_alloc from generate().
assert
events
[
"storage1-SCHEDULER"
][:
4
]
==
[
"set_xfer_handshake_metadata"
,
"get_num_new_matched_tokens 0"
,
"update_state_after_alloc num_blocks=[0] 0"
,
"build_connector_meta"
,
]
assert
events
[
"storage1-WORKER"
][:
5
]
==
[
# First three events are from initialization (register_kv_caches,
# set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events.
assert
events
[
"storage1-WORKER"
][:
7
]
==
[
"register_kv_caches"
,
"set_host_xfer_buffer_ops"
,
"get_handshake_metadata"
,
"bind_connector_metadata"
,
"start_load_kv"
,
"wait_for_layer_load"
,
"save_kv_layer"
,
]
assert
events
[
"storage2-SCHEDULER"
][:
3
]
==
[
assert
events
[
"storage2-SCHEDULER"
][:
4
]
==
[
"set_xfer_handshake_metadata"
,
"get_num_new_matched_tokens 0"
,
"update_state_after_alloc num_blocks=[0] 0"
,
"build_connector_meta"
,
]
assert
events
[
"storage2-WORKER"
][:
5
]
==
[
assert
events
[
"storage2-WORKER"
][:
7
]
==
[
"register_kv_caches"
,
"set_host_xfer_buffer_ops"
,
"get_handshake_metadata"
,
"bind_connector_metadata"
,
"start_load_kv"
,
"wait_for_layer_load"
,
...
...
@@ -297,6 +305,90 @@ def test_engine_id_conflict():
)
def test_multi_connector_handle_preemptions_integration():
    """
    Integration test: handle_preemptions on a MultiConnector must reach
    every sub-connector.

    Two TestExampleConnector sub-connectors (which log their method calls to
    temp files) are configured under a single MultiConnector. The test calls
    handle_preemptions on the scheduler's connector and then checks that the
    call shows up in the logged events of both sub-connectors.
    """
    from tests.v1.kv_connector.unit.utils import (
        create_scheduler,
        create_vllm_config,
    )

    storage_path = Path(tempfile.mkdtemp())
    try:
        # One TestExampleConnector entry per (storage subdirectory, name) pair.
        sub_connector_configs = [
            {
                "kv_connector": "TestExampleConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_path / subdir),
                    "name": connector_name,
                },
                "kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
            }
            for subdir, connector_name in (("s1", "preempt1"), ("s2", "preempt2"))
        ]
        kv_transfer_config = KVTransferConfig(
            kv_connector="MultiConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"connectors": sub_connector_configs},
        )

        vllm_config = create_vllm_config(
            block_size=16,
            max_num_batched_tokens=100,
            kv_connector_extra_config=kv_transfer_config.kv_connector_extra_config,
        )
        vllm_config.kv_transfer_config = kv_transfer_config

        # Building the scheduler instantiates the MultiConnector with the
        # SCHEDULER role.
        scheduler = create_scheduler(vllm_config, num_blocks=10)

        # Drop whatever was logged during initialization.
        get_connector_events()

        # handle_preemptions is normally a worker-side method; it is invoked
        # on the scheduler's connector here purely to exercise the
        # delegation behavior of MultiConnector.
        assert scheduler.connector is not None, "Scheduler should have a connector"

        preempted_req_ids = {"req-1", "req-2", "req-3"}
        scheduler.connector.handle_preemptions(preempted_req_ids)

        # Each SCHEDULER-role sub-connector should have logged the call.
        events = get_connector_events()
        for event_key in ("preempt1-SCHEDULER", "preempt2-SCHEDULER"):
            assert "handle_preemptions" in events.get(event_key, []), (
                f"{event_key} should have handle_preemptions call. "
                f"Got events: {events}"
            )
    finally:
        # Best-effort removal of the temporary storage directory.
        shutil.rmtree(storage_path, ignore_errors=True)
class
TestMultiConnectorStats
:
"""Tests for MultiConnector stats reconstruction and operations."""
...
...
@@ -647,3 +739,39 @@ class TestMultiConnectorPreferCrossLayerBlocks:
MockConnector
.
__new__
(
MockConnector
),
# default False
]
assert
mc
.
prefer_cross_layer_blocks
is
False
def test_multi_connector_overrides_all_base_methods():
    """
    Guard against MultiConnector silently inheriting public members of
    KVConnectorBase_V1 instead of overriding them.

    Every public, non-abstract name on KVConnectorBase_V1 must either appear
    directly in MultiConnector.__dict__ or be listed in INHERITED_OK.
    """
    # These are fine to inherit from KVConnectorBase_V1
    # TODO(https://github.com/vllm-project/vllm/pull/31811): Remove
    # get_kv_connector_kv_cache_events from INHERITED_OK once implemented.
    INHERITED_OK = {
        "role",
        "has_connector_metadata",
        "get_kv_connector_kv_cache_events",
    }

    public_names = {
        attr for attr in dir(KVConnectorBase_V1) if not attr.startswith("_")
    }
    # Abstract methods are excluded from the check, as are the names that
    # are explicitly allowed to be inherited.
    must_override = (
        public_names - KVConnectorBase_V1.__abstractmethods__ - INHERITED_OK
    )
    # Only names defined on MultiConnector itself count as overrides.
    missing = sorted(must_override.difference(MultiConnector.__dict__))

    if not missing:
        return

    pytest.fail(f"""
MultiConnector does not override these KVConnectorBase_V1 methods: {missing}
MultiConnector wraps other connectors and must delegate all methods.
Please add overrides that delegate to self._connectors.
Options:
1. Add delegation in MultiConnector (preferred)
2. Add to INHERITED_OK if the base implementation works correctly
""")
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
aca59674
...
...
@@ -12,7 +12,9 @@ from vllm.config.kv_transfer import KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
KVConnectorBase_V1
,
KVConnectorHandshakeMetadata
,
KVConnectorMetadata
,
KVConnectorRole
,
)
...
...
@@ -272,11 +274,26 @@ class MultiConnector(KVConnectorBase_V1):
agg_block_ids
|=
c
.
get_block_ids_with_load_errors
()
return
agg_block_ids
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
# for the MultiConnector. It should be able to get events from multiple
# connectors, handling the case where only a subset of the requested connectors
# implements the 'get_kv_connector_kv_cache_events'
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
    """Forward the xPU-specific copy operation to every sub-connector.

    Args:
        copy_operation: The copy op handed unchanged to each wrapped
            connector's ``set_host_xfer_buffer_ops``.
    """
    for sub_connector in self._connectors:
        sub_connector.set_host_xfer_buffer_ops(copy_operation)
def handle_preemptions(self, preempted_req_ids: set[str]):
    """Notify every sub-connector about the preempted requests.

    Args:
        preempted_req_ids: IDs of the preempted requests; the same set
            object is passed to each wrapped connector in turn.
    """
    for sub_connector in self._connectors:
        sub_connector.handle_preemptions(preempted_req_ids)
def get_finished_count(self) -> int | None:
    """Return the finished count for this connector; currently always None.

    Sub-connectors are not consulted yet, see the TODO below.
    """
    # TODO(https://github.com/vllm-project/vllm/issues/33400)
    # Currently no connectors return non-None
    return None
# TODO: Add a generic implementation of the 'get_kv_connector_kv_cache_events'
# method for the MultiConnector. It should be able to get events from
# multiple connectors, handling the case where only a subset of the
# requested connectors implements the 'get_kv_connector_kv_cache_events' method.
# WIP: https://github.com/vllm-project/vllm/pull/31811
# ==============================
# Scheduler-side methods
...
...
@@ -332,6 +349,27 @@ class MultiConnector(KVConnectorBase_V1):
for
c
in
self
.
_connectors
:
c
.
update_connector_output
(
connector_output
)
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Fetch KVConnector handshake metadata from the sub-connectors.

    Sub-connectors are queried in order and the first result that is not
    None is returned; later sub-connectors are not queried once one has
    answered. Returns None when no sub-connector provides metadata.
    """
    # Lazy generator: each sub-connector is only asked when the previous
    # ones returned None.
    candidates = (
        sub_connector.get_handshake_metadata()
        for sub_connector in self._connectors
    )
    return next(
        (candidate for candidate in candidates if candidate is not None),
        None,
    )
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Pass the KV connector handshake metadata on to every sub-connector.

    This is needed to start the NIXL listener thread for NixlConnector.

    Args:
        metadata: Mapping handed unchanged to each wrapped connector's
            ``set_xfer_handshake_metadata``.
    """
    for sub_connector in self._connectors:
        sub_connector.set_xfer_handshake_metadata(metadata)
def
request_finished
(
self
,
request
:
"Request"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment