Unverified commit 4c6fd258, authored by Or Ozeri and committed by GitHub
Browse files

kv_transfer: Rename the shared storage connectors (#30201)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent 03b91f72
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501
SharedStorageConnectorMetadata,
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa: E501
ExampleConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized,
......@@ -11,7 +11,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
# Importing utils registers TestSharedStorageConnector with the factory
# Importing utils registers TestExampleConnector with the factory
from .utils import create_vllm_config
......@@ -26,13 +26,13 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
kv_connector_metadata=SharedStorageConnectorMetadata(),
kv_connector_metadata=ExampleConnectorMetadata(),
)
def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector"
vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
......
......@@ -77,9 +77,9 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_multi_shared_storage_connector_consistency():
def test_multi_example_connector_consistency():
"""
Tests that MultiConnector with two SharedStorageConnectors saves
Tests that MultiConnector with two ExampleConnectors saves
identical KV cache data to separate storage locations.
"""
storage_1_path = Path("storage_1/")
......@@ -89,14 +89,14 @@ def test_multi_shared_storage_connector_consistency():
storage_1_path.mkdir()
storage_2_path.mkdir()
# Configure MultiConnector with two SharedStorageConnectors
# Configure MultiConnector with two ExampleConnectors
kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [
{
"kv_connector": "TestSharedStorageConnector",
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path),
......@@ -105,7 +105,7 @@ def test_multi_shared_storage_connector_consistency():
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
{
"kv_connector": "TestSharedStorageConnector",
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path),
......@@ -427,7 +427,7 @@ class TestMultiConnectorStats:
def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
"""Test that connectors without custom stats (return None) are skipped."""
# SharedStorageConnector doesn't override build_kv_connector_stats,
# ExampleConnector doesn't override build_kv_connector_stats,
# so it returns None and should be skipped
serialized_data = {
"NixlConnector": {
......@@ -440,7 +440,7 @@ class TestMultiConnectorStats:
"num_failed_notifications": [],
}
},
"SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}},
"ExampleConnector": {"data": {"some_field": [1, 2, 3]}},
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
......@@ -451,8 +451,8 @@ class TestMultiConnectorStats:
assert len(stats.data) == 1
assert "NixlConnector" in stats.data
assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
# SharedStorageConnector should be skipped (returns None)
assert "SharedStorageConnector" not in stats.data
# ExampleConnector should be skipped (returns None)
assert "ExampleConnector" not in stats.data
def test_build_kv_connector_stats_handles_malformed_data(self):
"""Test that malformed data raises appropriate errors."""
......@@ -527,13 +527,13 @@ class TestMultiConnectorStats:
)
stats2 = MultiKVConnectorStats(
data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})}
data={"ExampleConnector": KVConnectorStats(data={"field": [1, 2]})}
)
result = stats1.aggregate(stats2)
assert "NixlConnector" in result.data
assert "SharedStorageConnector" in result.data
assert "ExampleConnector" in result.data
def test_reduce(self):
"""Test that reduce() correctly reduces all nested connector stats."""
......
......@@ -24,8 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector,
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa
ExampleConnector,
)
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......@@ -264,10 +264,10 @@ def create_model_runner_output(
)
class TestSharedStorageConnector(SharedStorageConnector):
class TestExampleConnector(ExampleConnector):
def __init__(self, config: VllmConfig, role, kv_cache_config):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self._connector = ExampleConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = (
......@@ -394,7 +394,7 @@ class MockKVConnector(KVConnectorBase_V1):
KVConnectorFactory.register_connector(
"TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
"TestExampleConnector", __name__, TestExampleConnector.__name__
)
KVConnectorFactory.register_connector(
......
......@@ -32,7 +32,7 @@ class MMMeta:
@dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
class ECExampleConnectorMetadata(ECConnectorMetadata):
mm_datas: list[MMMeta]
def __init__(self):
......@@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
self.mm_datas.append(mm_data)
class ECSharedStorageConnector(ECConnectorBase):
class ECExampleConnector(ECConnectorBase):
# NOTE: This is a simple debug implementation of the EC connector.
# It saves / loads the EC cache to / from the disk.
......@@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase):
# Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert isinstance(metadata, ECExampleConnectorMetadata)
assert encoder_cache is not None
if metadata is None:
logger.warning(
......@@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase):
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = ECSharedStorageConnectorMetadata()
meta = ECExampleConnectorMetadata()
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
self._mm_datas_need_loads.clear()
......
......@@ -79,7 +79,7 @@ class ECConnectorFactory:
# only load the files corresponding to the current connector.
ECConnectorFactory.register_connector(
"ECSharedStorageConnector",
"vllm.distributed.ec_transfer.ec_connector.shared_storage_connector",
"ECSharedStorageConnector",
"ECExampleConnector",
"vllm.distributed.ec_transfer.ec_connector.example_connector",
"ECExampleConnector",
)
......@@ -144,9 +144,9 @@ class KVConnectorFactory:
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"SharedStorageConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector",
"ExampleConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.example_connector",
"ExampleConnector",
)
KVConnectorFactory.register_connector(
......
......@@ -65,7 +65,7 @@ class ReqMeta:
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
class ExampleConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
......@@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata):
)
class SharedStorageConnector(KVConnectorBase_V1):
class ExampleConnector(KVConnectorBase_V1):
# NOTE: This is a simple debug implementation of the KV connector.
# It saves / loads the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU
......@@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
assert isinstance(metadata, ExampleConnectorMetadata)
if metadata is None:
logger.warning(
......@@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
assert isinstance(connector_metadata, ExampleConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
......@@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
meta = ExampleConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment