Unverified Commit 4c6fd258 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

kv_transfer: Rename the shared storage connectors (#30201)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent 03b91f72
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501 from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa: E501
SharedStorageConnectorMetadata, ExampleConnectorMetadata,
) )
from vllm.distributed.kv_transfer.kv_transfer_state import ( from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, ensure_kv_transfer_initialized,
...@@ -11,7 +11,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import ( ...@@ -11,7 +11,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
# Importing utils registers TestSharedStorageConnector with the factory # Importing utils registers TestExampleConnector with the factory
from .utils import create_vllm_config from .utils import create_vllm_config
...@@ -26,13 +26,13 @@ def _make_empty_scheduler_output(): ...@@ -26,13 +26,13 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
kv_connector_metadata=SharedStorageConnectorMetadata(), kv_connector_metadata=ExampleConnectorMetadata(),
) )
def test_kv_connector_mixin_clears_metadata(): def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector" vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both" vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit" vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
......
...@@ -77,9 +77,9 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool: ...@@ -77,9 +77,9 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
"https://github.com/ROCm/pytorch/issues/2822" "https://github.com/ROCm/pytorch/issues/2822"
), ),
) )
def test_multi_shared_storage_connector_consistency(): def test_multi_example_connector_consistency():
""" """
Tests that MultiConnector with two SharedStorageConnectors saves Tests that MultiConnector with two ExampleConnectors saves
identical KV cache data to separate storage locations. identical KV cache data to separate storage locations.
""" """
storage_1_path = Path("storage_1/") storage_1_path = Path("storage_1/")
...@@ -89,14 +89,14 @@ def test_multi_shared_storage_connector_consistency(): ...@@ -89,14 +89,14 @@ def test_multi_shared_storage_connector_consistency():
storage_1_path.mkdir() storage_1_path.mkdir()
storage_2_path.mkdir() storage_2_path.mkdir()
# Configure MultiConnector with two SharedStorageConnectors # Configure MultiConnector with two ExampleConnectors
kv_transfer_config = KVTransferConfig( kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector", kv_connector="MultiConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={
"connectors": [ "connectors": [
{ {
"kv_connector": "TestSharedStorageConnector", "kv_connector": "TestExampleConnector",
"kv_role": "kv_both", "kv_role": "kv_both",
"kv_connector_extra_config": { "kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path), "shared_storage_path": str(storage_1_path),
...@@ -105,7 +105,7 @@ def test_multi_shared_storage_connector_consistency(): ...@@ -105,7 +105,7 @@ def test_multi_shared_storage_connector_consistency():
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils", "kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
}, },
{ {
"kv_connector": "TestSharedStorageConnector", "kv_connector": "TestExampleConnector",
"kv_role": "kv_both", "kv_role": "kv_both",
"kv_connector_extra_config": { "kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path), "shared_storage_path": str(storage_2_path),
...@@ -427,7 +427,7 @@ class TestMultiConnectorStats: ...@@ -427,7 +427,7 @@ class TestMultiConnectorStats:
def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self): def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
"""Test that connectors without custom stats (return None) are skipped.""" """Test that connectors without custom stats (return None) are skipped."""
# SharedStorageConnector doesn't override build_kv_connector_stats, # ExampleConnector doesn't override build_kv_connector_stats,
# so it returns None and should be skipped # so it returns None and should be skipped
serialized_data = { serialized_data = {
"NixlConnector": { "NixlConnector": {
...@@ -440,7 +440,7 @@ class TestMultiConnectorStats: ...@@ -440,7 +440,7 @@ class TestMultiConnectorStats:
"num_failed_notifications": [], "num_failed_notifications": [],
} }
}, },
"SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}}, "ExampleConnector": {"data": {"some_field": [1, 2, 3]}},
} }
stats = MultiConnector.build_kv_connector_stats(data=serialized_data) stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
...@@ -451,8 +451,8 @@ class TestMultiConnectorStats: ...@@ -451,8 +451,8 @@ class TestMultiConnectorStats:
assert len(stats.data) == 1 assert len(stats.data) == 1
assert "NixlConnector" in stats.data assert "NixlConnector" in stats.data
assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats) assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
# SharedStorageConnector should be skipped (returns None) # ExampleConnector should be skipped (returns None)
assert "SharedStorageConnector" not in stats.data assert "ExampleConnector" not in stats.data
def test_build_kv_connector_stats_handles_malformed_data(self): def test_build_kv_connector_stats_handles_malformed_data(self):
"""Test that malformed data raises appropriate errors.""" """Test that malformed data raises appropriate errors."""
...@@ -527,13 +527,13 @@ class TestMultiConnectorStats: ...@@ -527,13 +527,13 @@ class TestMultiConnectorStats:
) )
stats2 = MultiKVConnectorStats( stats2 = MultiKVConnectorStats(
data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})} data={"ExampleConnector": KVConnectorStats(data={"field": [1, 2]})}
) )
result = stats1.aggregate(stats2) result = stats1.aggregate(stats2)
assert "NixlConnector" in result.data assert "NixlConnector" in result.data
assert "SharedStorageConnector" in result.data assert "ExampleConnector" in result.data
def test_reduce(self): def test_reduce(self):
"""Test that reduce() correctly reduces all nested connector stats.""" """Test that reduce() correctly reduces all nested connector stats."""
......
...@@ -24,8 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ...@@ -24,8 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata, KVConnectorMetadata,
KVConnectorRole, KVConnectorRole,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa
SharedStorageConnector, ExampleConnector,
) )
from vllm.utils.hashing import sha256 from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
...@@ -264,10 +264,10 @@ def create_model_runner_output( ...@@ -264,10 +264,10 @@ def create_model_runner_output(
) )
class TestSharedStorageConnector(SharedStorageConnector): class TestExampleConnector(ExampleConnector):
def __init__(self, config: VllmConfig, role, kv_cache_config): def __init__(self, config: VllmConfig, role, kv_cache_config):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"] self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role) self._connector = ExampleConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int) self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector # Use a unique temp file per connector
self._event_file = ( self._event_file = (
...@@ -394,7 +394,7 @@ class MockKVConnector(KVConnectorBase_V1): ...@@ -394,7 +394,7 @@ class MockKVConnector(KVConnectorBase_V1):
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__ "TestExampleConnector", __name__, TestExampleConnector.__name__
) )
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
......
...@@ -32,7 +32,7 @@ class MMMeta: ...@@ -32,7 +32,7 @@ class MMMeta:
@dataclass @dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata): class ECExampleConnectorMetadata(ECConnectorMetadata):
mm_datas: list[MMMeta] mm_datas: list[MMMeta]
def __init__(self): def __init__(self):
...@@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata): ...@@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
self.mm_datas.append(mm_data) self.mm_datas.append(mm_data)
class ECSharedStorageConnector(ECConnectorBase): class ECExampleConnector(ECConnectorBase):
# NOTE: This is Simple debug implementation of the EC connector. # NOTE: This is Simple debug implementation of the EC connector.
# It save / load the EC cache to / from the disk. # It save / load the EC cache to / from the disk.
...@@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase): ...@@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase):
# Get the metadata # Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata() metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata) assert isinstance(metadata, ECExampleConnectorMetadata)
assert encoder_cache is not None assert encoder_cache is not None
if metadata is None: if metadata is None:
logger.warning( logger.warning(
...@@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase): ...@@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase):
Args: Args:
scheduler_output (SchedulerOutput): the scheduler output object. scheduler_output (SchedulerOutput): the scheduler output object.
""" """
meta = ECSharedStorageConnectorMetadata() meta = ECExampleConnectorMetadata()
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items(): for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token)) meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
self._mm_datas_need_loads.clear() self._mm_datas_need_loads.clear()
......
...@@ -79,7 +79,7 @@ class ECConnectorFactory: ...@@ -79,7 +79,7 @@ class ECConnectorFactory:
# only load the files corresponding to the current connector. # only load the files corresponding to the current connector.
ECConnectorFactory.register_connector( ECConnectorFactory.register_connector(
"ECSharedStorageConnector", "ECExampleConnector",
"vllm.distributed.ec_transfer.ec_connector.shared_storage_connector", "vllm.distributed.ec_transfer.ec_connector.example_connector",
"ECSharedStorageConnector", "ECExampleConnector",
) )
...@@ -144,9 +144,9 @@ class KVConnectorFactory: ...@@ -144,9 +144,9 @@ class KVConnectorFactory:
# only load the files corresponding to the current connector. # only load the files corresponding to the current connector.
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"SharedStorageConnector", "ExampleConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", "vllm.distributed.kv_transfer.kv_connector.v1.example_connector",
"SharedStorageConnector", "ExampleConnector",
) )
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
......
...@@ -65,7 +65,7 @@ class ReqMeta: ...@@ -65,7 +65,7 @@ class ReqMeta:
@dataclass @dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata): class ExampleConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list) requests: list[ReqMeta] = field(default_factory=list)
def add_request( def add_request(
...@@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata): ...@@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata):
) )
class SharedStorageConnector(KVConnectorBase_V1): class ExampleConnector(KVConnectorBase_V1):
# NOTE: This is Simple debug implementation of the KV connector. # NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk. # It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU # It does extra work which will overwrite the existing prefix-cache in GPU
...@@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# Get the metadata # Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata() metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata) assert isinstance(metadata, ExampleConnectorMetadata)
if metadata is None: if metadata is None:
logger.warning( logger.warning(
...@@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata) assert isinstance(connector_metadata, ExampleConnectorMetadata)
for request in connector_metadata.requests: for request in connector_metadata.requests:
if request.is_store: if request.is_store:
filename = self._generate_filename_debug( filename = self._generate_filename_debug(
...@@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
Args: Args:
scheduler_output (SchedulerOutput): the scheduler output object. scheduler_output (SchedulerOutput): the scheduler output object.
""" """
meta = SharedStorageConnectorMetadata() meta = ExampleConnectorMetadata()
total_need_load = 0 total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs: for new_req in scheduler_output.scheduled_new_reqs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment