Unverified commit aca59674, authored by Seiji Eicher and committed by GitHub
Browse files

[KV Connector] Add missing method overrides to MultiConnector (#33292)


Signed-off-by: Seiji Eicher <seiji@anyscale.com>
parent 67a746e8
......@@ -190,27 +190,35 @@ def test_multi_example_connector_consistency():
)
events = get_connector_events()
# get_num_new_matched_tokens and update_state_after_alloc will be called
# on each connector in turn.
assert events["storage1-SCHEDULER"][:3] == [
# First event is set_xfer_handshake_metadata from initialization, then
# get_num_new_matched_tokens and update_state_after_alloc from generate().
assert events["storage1-SCHEDULER"][:4] == [
"set_xfer_handshake_metadata",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage1-WORKER"][:5] == [
# First three events are from initialization (register_kv_caches,
# set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events.
assert events["storage1-WORKER"][:7] == [
"register_kv_caches",
"set_host_xfer_buffer_ops",
"get_handshake_metadata",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
"save_kv_layer",
]
assert events["storage2-SCHEDULER"][:3] == [
assert events["storage2-SCHEDULER"][:4] == [
"set_xfer_handshake_metadata",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-WORKER"][:5] == [
assert events["storage2-WORKER"][:7] == [
"register_kv_caches",
"set_host_xfer_buffer_ops",
"get_handshake_metadata",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
......@@ -297,6 +305,90 @@ def test_engine_id_conflict():
)
def test_multi_connector_handle_preemptions_integration():
    """
    Integration test: MultiConnector forwards handle_preemptions to every
    sub-connector.

    A MultiConnector is configured with two TestExampleConnector
    sub-connectors (which log their method calls to temp files). The test
    calls handle_preemptions directly on the scheduler's connector and then
    checks that both sub-connectors logged the call.
    """
    from tests.v1.kv_connector.unit.utils import (
        create_scheduler,
        create_vllm_config,
    )

    storage_path = Path(tempfile.mkdtemp())
    try:
        # One TestExampleConnector spec per (storage subdirectory, name) pair.
        sub_connectors = [
            {
                "kv_connector": "TestExampleConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_path / subdir),
                    "name": connector_name,
                },
                "kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
            }
            for subdir, connector_name in (("s1", "preempt1"), ("s2", "preempt2"))
        ]
        kv_transfer_config = KVTransferConfig(
            kv_connector="MultiConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"connectors": sub_connectors},
        )

        vllm_config = create_vllm_config(
            block_size=16,
            max_num_batched_tokens=100,
            kv_connector_extra_config=kv_transfer_config.kv_connector_extra_config,
        )
        vllm_config.kv_transfer_config = kv_transfer_config

        # Creating the scheduler initializes the MultiConnector with the
        # SCHEDULER role.
        scheduler = create_scheduler(vllm_config, num_blocks=10)

        # Discard whatever was logged during initialization.
        get_connector_events()

        # handle_preemptions is normally a worker-side method; it is invoked
        # on the scheduler's connector here only to exercise the delegation
        # behavior of MultiConnector.
        assert scheduler.connector is not None, "Scheduler should have a connector"
        preempted_req_ids = {"req-1", "req-2", "req-3"}
        scheduler.connector.handle_preemptions(preempted_req_ids)

        # Each SCHEDULER-role sub-connector must have logged the call.
        events = get_connector_events()
        for connector_name in ("preempt1", "preempt2"):
            event_key = f"{connector_name}-SCHEDULER"
            assert "handle_preemptions" in events.get(event_key, []), (
                f"{event_key} should have handle_preemptions call. "
                f"Got events: {events}"
            )
    finally:
        # Cleanup
        shutil.rmtree(storage_path, ignore_errors=True)
class TestMultiConnectorStats:
"""Tests for MultiConnector stats reconstruction and operations."""
......@@ -647,3 +739,39 @@ class TestMultiConnectorPreferCrossLayerBlocks:
MockConnector.__new__(MockConnector), # default False
]
assert mc.prefer_cross_layer_blocks is False
def test_multi_connector_overrides_all_base_methods():
    """
    Check that MultiConnector defines every public, non-abstract member of
    KVConnectorBase_V1 in its own class body.

    Members listed in INHERITED_OK are exempt. The test fails with the list
    of members that MultiConnector merely inherits.
    """
    # These are fine to inherit from KVConnectorBase_V1
    # TODO(https://github.com/vllm-project/vllm/pull/31811): Remove
    # get_kv_connector_kv_cache_events from INHERITED_OK once implemented.
    INHERITED_OK = {
        "role",
        "has_connector_metadata",
        "get_kv_connector_kv_cache_events",
    }

    public_names = {
        attr for attr in dir(KVConnectorBase_V1) if not attr.startswith("_")
    }
    base_members = public_names - KVConnectorBase_V1.__abstractmethods__

    # Only names defined directly on MultiConnector count as overrides.
    own_members = MultiConnector.__dict__
    missing = sorted(
        attr
        for attr in base_members
        if not (attr in INHERITED_OK or attr in own_members)
    )

    if not missing:
        return

    pytest.fail(f"""
MultiConnector does not override these KVConnectorBase_V1 methods: {missing}
MultiConnector wraps other connectors and must delegate all methods.
Please add overrides that delegate to self._connectors.
Options:
1. Add delegation in MultiConnector (preferred)
2. Add to INHERITED_OK if the base implementation works correctly
""")
......@@ -12,7 +12,9 @@ from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
KVConnectorRole,
)
......@@ -272,11 +274,26 @@ class MultiConnector(KVConnectorBase_V1):
agg_block_ids |= c.get_block_ids_with_load_errors()
return agg_block_ids
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
# for the MultiConnector. It should be able to get events from multiple
# connectors, handling the case where only a subset of the requested connectors
# implements the 'get_kv_connector_kv_cache_events'
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp) -> None:
    """Set xPU-specific copy ops for all sub-connectors.

    Args:
        copy_operation: Copy operation handed unchanged to each
            sub-connector's ``set_host_xfer_buffer_ops``.
    """
    # Pure fan-out: every sub-connector receives the same operation, in
    # the order the sub-connectors are stored.
    for c in self._connectors:
        c.set_host_xfer_buffer_ops(copy_operation)
def handle_preemptions(self, preempted_req_ids: set[str]) -> None:
    """Handle preempted requests for all sub-connectors.

    Args:
        preempted_req_ids: IDs of the preempted requests; the same set
            object is passed to each sub-connector's
            ``handle_preemptions``.
    """
    for c in self._connectors:
        c.handle_preemptions(preempted_req_ids)
def get_finished_count(self) -> int | None:
    """Return ``None``; sub-connectors are not consulted.

    No aggregation across sub-connectors is implemented here (see the
    TODO below).
    """
    # TODO(https://github.com/vllm-project/vllm/issues/33400)
    # Currently no connectors return non-None
    return None
# TODO: Add a generic implementation of the 'get_kv_connector_kv_cache_events'
# method for the MultiConnector. It should be able to get events from
# multiple connectors, handling the case where only a subset of the
# requested connectors implements 'get_kv_connector_kv_cache_events'.
# WIP: https://github.com/vllm-project/vllm/pull/31811
# ==============================
# Scheduler-side methods
......@@ -332,6 +349,27 @@ class MultiConnector(KVConnectorBase_V1):
for c in self._connectors:
c.update_connector_output(connector_output)
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Get the KVConnector handshake metadata from sub-connectors.

    Sub-connectors are queried in order and the first result that is not
    ``None`` is returned; sub-connectors after that one are not queried.
    Returns ``None`` when every sub-connector returns ``None`` (or there
    are no sub-connectors).
    """
    # The inner generator is lazy, so querying stops at the first hit.
    candidates = (sub.get_handshake_metadata() for sub in self._connectors)
    return next((found for found in candidates if found is not None), None)
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Pass the KV connector handshake metadata on to all sub-connectors.

    Every sub-connector receives the same ``metadata`` mapping.
    This is needed to start the NIXL listener thread for NixlConnector.
    """
    for sub_connector in self._connectors:
        sub_connector.set_xfer_handshake_metadata(metadata)
def request_finished(
self,
request: "Request",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment