Unverified Commit 4c6fd258 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

kv_transfer: Rename the shared storage connectors (#30201)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent 03b91f72
...@@ -47,6 +47,6 @@ docker run \ ...@@ -47,6 +47,6 @@ docker run \
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
pytest -v -s v1/structured_output pytest -v -s v1/structured_output
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py
pytest -v -s v1/test_serial_utils.py pytest -v -s v1/test_serial_utils.py
' '
...@@ -32,14 +32,14 @@ Design doc: <https://docs.google.com/document/d/1aed8KtC6XkXtdoV87pWT0a8OJlZ-Cpn ...@@ -32,14 +32,14 @@ Design doc: <https://docs.google.com/document/d/1aed8KtC6XkXtdoV87pWT0a8OJlZ-Cpn
## 2 Usage Example ## 2 Usage Example
The current reference pathway is **SharedStorageConnector**. The current reference pathway is **ExampleConnector**.
Below ready-to-run scripts shows the workflow: Below ready-to-run scripts shows the workflow:
1 Encoder instance + 1 PD instance: 1 Encoder instance + 1 PD instance:
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_encoder_example.sh` `examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh`
1 Encoder instance + 1 Prefill instance + 1 Decode instance: 1 Encoder instance + 1 Prefill instance + 1 Decode instance:
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_epd_example.sh` `examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh`
--- ---
......
...@@ -21,14 +21,14 @@ Please refer to [examples/online_serving/disaggregated_prefill.sh](../../example ...@@ -21,14 +21,14 @@ Please refer to [examples/online_serving/disaggregated_prefill.sh](../../example
Now supports 5 types of connectors: Now supports 5 types of connectors:
- **SharedStorageConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of SharedStorageConnector disaggregated prefilling. - **ExampleConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of ExampleConnector disaggregated prefilling.
- **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission. - **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md). - **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
- **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling. - **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling.
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as: - **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:
```bash ```bash
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}' --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"ExampleConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
``` ```
For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as: For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as:
......
...@@ -30,7 +30,7 @@ def main(): ...@@ -30,7 +30,7 @@ def main():
max_num_batched_tokens=64, max_num_batched_tokens=64,
max_num_seqs=16, max_num_seqs=16,
kv_transfer_config=KVTransferConfig( kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
), ),
......
...@@ -26,7 +26,7 @@ def main(): ...@@ -26,7 +26,7 @@ def main():
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.8, gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig( kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
), ),
......
...@@ -10,7 +10,7 @@ It demonstrates vLLM's ability to recover from KV load failures in both synchron ...@@ -10,7 +10,7 @@ It demonstrates vLLM's ability to recover from KV load failures in both synchron
- `decode_example.py` – performs the decode stage. Accepts: - `decode_example.py` – performs the decode stage. Accepts:
- `--simulate-failure`: simulates KV load failure using a custom connector. - `--simulate-failure`: simulates KV load failure using a custom connector.
- `--async-load`: enables asynchronous KV loading mode. - `--async-load`: enables asynchronous KV loading mode.
- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request. - `load_recovery_example_connector.py` – defines `LoadRecoveryExampleConnector`, a subclass of `ExampleConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request.
- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages: - `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages:
1. Normal decode (baseline). 1. Normal decode (baseline).
2. Decode with simulated sync KV load failure. 2. Decode with simulated sync KV load failure.
...@@ -20,7 +20,7 @@ It demonstrates vLLM's ability to recover from KV load failures in both synchron ...@@ -20,7 +20,7 @@ It demonstrates vLLM's ability to recover from KV load failures in both synchron
## How It Works ## How It Works
- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector. - The test dynamically loads `LoadRecoveryExampleConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector.
- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode. - The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with error. - If recovery fails, the script prints a unified diff of the output mismatch and exits with error.
......
...@@ -35,13 +35,13 @@ def main(): ...@@ -35,13 +35,13 @@ def main():
if args.simulate_failure: if args.simulate_failure:
ktc = KVTransferConfig( ktc = KVTransferConfig(
kv_connector="RogueSharedStorageConnector", kv_connector="LoadRecoveryExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={
"shared_storage_path": "local_storage", "shared_storage_path": "local_storage",
"async_load": args.async_load, "async_load": args.async_load,
}, },
kv_connector_module_path="rogue_shared_storage_connector", kv_connector_module_path="load_recovery_example_connector",
) )
out_file = ( out_file = (
"async_decode_recovered_output.txt" "async_decode_recovered_output.txt"
...@@ -50,7 +50,7 @@ def main(): ...@@ -50,7 +50,7 @@ def main():
) )
else: else:
ktc = KVTransferConfig( ktc = KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={
"shared_storage_path": "local_storage", "shared_storage_path": "local_storage",
......
...@@ -10,9 +10,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ...@@ -10,9 +10,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata, KVConnectorMetadata,
KVConnectorRole, KVConnectorRole,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
SharedStorageConnector, ExampleConnector,
SharedStorageConnectorMetadata, ExampleConnectorMetadata,
) )
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
...@@ -26,15 +26,15 @@ logging.basicConfig(level=logging.INFO) ...@@ -26,15 +26,15 @@ logging.basicConfig(level=logging.INFO)
@dataclass @dataclass
class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata): class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata):
req_to_block_ids: dict[str, set[int]] = field(default_factory=dict) req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)
@classmethod @classmethod
def from_base(cls, base: SharedStorageConnectorMetadata): def from_base(cls, base: ExampleConnectorMetadata):
return cls(requests=base.requests) return cls(requests=base.requests)
class RogueSharedStorageConnector(SharedStorageConnector): class LoadRecoveryExampleConnector(ExampleConnector):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role) super().__init__(vllm_config=vllm_config, role=role)
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config( self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
...@@ -45,7 +45,7 @@ class RogueSharedStorageConnector(SharedStorageConnector): ...@@ -45,7 +45,7 @@ class RogueSharedStorageConnector(SharedStorageConnector):
self._req_to_block_ids: dict[str, list[int]] = dict() self._req_to_block_ids: dict[str, list[int]] = dict()
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata) assert isinstance(connector_metadata, LoadRecoveryExampleConnectorMetadata)
index, failed_request = next( index, failed_request = next(
( (
(i, x) (i, x)
...@@ -84,7 +84,7 @@ class RogueSharedStorageConnector(SharedStorageConnector): ...@@ -84,7 +84,7 @@ class RogueSharedStorageConnector(SharedStorageConnector):
) -> tuple[set[str] | None, set[str] | None]: ) -> tuple[set[str] | None, set[str] | None]:
if self._async_load: if self._async_load:
meta = self._get_connector_metadata() meta = self._get_connector_metadata()
assert isinstance(meta, RogueSharedStorageConnectorMetadata) assert isinstance(meta, LoadRecoveryExampleConnectorMetadata)
if meta.req_to_block_ids: if meta.req_to_block_ids:
return None, set(meta.req_to_block_ids) return None, set(meta.req_to_block_ids)
...@@ -126,9 +126,9 @@ class RogueSharedStorageConnector(SharedStorageConnector): ...@@ -126,9 +126,9 @@ class RogueSharedStorageConnector(SharedStorageConnector):
) -> KVConnectorMetadata: ) -> KVConnectorMetadata:
if not self._async_load: if not self._async_load:
base = super().build_connector_meta(scheduler_output) base = super().build_connector_meta(scheduler_output)
meta = RogueSharedStorageConnectorMetadata.from_base(base) meta = LoadRecoveryExampleConnectorMetadata.from_base(base)
else: else:
meta = RogueSharedStorageConnectorMetadata() meta = LoadRecoveryExampleConnectorMetadata()
if self._requests_need_load: if self._requests_need_load:
for req_id, request in self._requests_need_load.items(): for req_id, request in self._requests_need_load.items():
meta.add_request( meta.add_request(
......
...@@ -26,7 +26,7 @@ def main(): ...@@ -26,7 +26,7 @@ def main():
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.8, gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig( kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
), ),
......
...@@ -50,12 +50,12 @@ The vllm instances and `disagg_encoder_proxy` supports local URIs with ```{"url" ...@@ -50,12 +50,12 @@ The vllm instances and `disagg_encoder_proxy` supports local URIs with ```{"url"
## EC connector and KV transfer ## EC connector and KV transfer
The `ECSharedStorageConnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration: The `ECExampleonnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration:
```bash ```bash
# Add to encoder instance: # Add to encoder instance:
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_producer", "ec_role": "ec_producer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
...@@ -64,7 +64,7 @@ The `ECSharedStorageConnector` is used to store the encoder cache on local disk ...@@ -64,7 +64,7 @@ The `ECSharedStorageConnector` is used to store the encoder cache on local disk
# Add to prefill/prefill+decode instance: # Add to prefill/prefill+decode instance:
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer", "ec_role": "ec_consumer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......
...@@ -102,7 +102,7 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ ...@@ -102,7 +102,7 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_producer", "ec_role": "ec_producer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
...@@ -126,7 +126,7 @@ vllm serve "$MODEL" \ ...@@ -126,7 +126,7 @@ vllm serve "$MODEL" \
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer", "ec_role": "ec_consumer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......
...@@ -96,7 +96,7 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ ...@@ -96,7 +96,7 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_producer", "ec_role": "ec_producer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
...@@ -117,7 +117,7 @@ CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ ...@@ -117,7 +117,7 @@ CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer", "ec_role": "ec_consumer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......
...@@ -61,7 +61,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector(): ...@@ -61,7 +61,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector():
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={ kv_connector_extra_config={
"connectors": [ "connectors": [
{"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"}, {"kv_connector": "ExampleConnector", "kv_role": "kv_both"},
{"kv_connector": "NixlConnector", "kv_role": "kv_both"}, {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
] ]
}, },
......
...@@ -1536,7 +1536,7 @@ def create_scheduler_with_priority( ...@@ -1536,7 +1536,7 @@ def create_scheduler_with_priority(
) )
kv_transfer_config = ( kv_transfer_config = (
KVTransferConfig( KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
) )
...@@ -1552,7 +1552,7 @@ def create_scheduler_with_priority( ...@@ -1552,7 +1552,7 @@ def create_scheduler_with_priority(
ec_transfer_config = ( ec_transfer_config = (
ECTransferConfig( ECTransferConfig(
ec_connector="ECSharedStorageConnector", ec_connector="ECExampleConnector",
ec_role=ec_role, ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
) )
...@@ -2413,7 +2413,7 @@ def _assert_right_ec_connector_metadata( ...@@ -2413,7 +2413,7 @@ def _assert_right_ec_connector_metadata(
metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas} metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas}
# Check all required identifiers exist in metadata; and no extra # Check all required identifiers exist in metadata; and no extra
# In ECSharedStorageConnector format # In ECExampleConnector format
# NOTE: even having same identifier, the mm_features can be different # NOTE: even having same identifier, the mm_features can be different
# since their mm_position can be in different offsets, etc # since their mm_position can be in different offsets, etc
identifiers_dict = {f.identifier for f in mm_features_list} identifiers_dict = {f.identifier for f in mm_features_list}
......
...@@ -108,7 +108,7 @@ def create_scheduler( ...@@ -108,7 +108,7 @@ def create_scheduler(
) )
elif use_kv_connector: elif use_kv_connector:
kv_transfer_config = KVTransferConfig( kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
) )
...@@ -121,7 +121,7 @@ def create_scheduler( ...@@ -121,7 +121,7 @@ def create_scheduler(
ec_transfer_config = ( ec_transfer_config = (
ECTransferConfig( ECTransferConfig(
ec_connector="ECSharedStorageConnector", ec_connector="ECExampleConnector",
ec_role=ec_role, ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
) )
......
...@@ -148,7 +148,7 @@ run_epd_1e_1pd() { ...@@ -148,7 +148,7 @@ run_epd_1e_1pd() {
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_producer", "ec_role": "ec_producer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
...@@ -167,7 +167,7 @@ run_epd_1e_1pd() { ...@@ -167,7 +167,7 @@ run_epd_1e_1pd() {
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer", "ec_role": "ec_consumer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
...@@ -348,7 +348,7 @@ run_epd_1e_1p_1d() { ...@@ -348,7 +348,7 @@ run_epd_1e_1p_1d() {
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_producer", "ec_role": "ec_producer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
...@@ -369,7 +369,7 @@ run_epd_1e_1p_1d() { ...@@ -369,7 +369,7 @@ run_epd_1e_1p_1d() {
--max-num-seqs 128 \ --max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{ --ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector", "ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer", "ec_role": "ec_consumer",
"ec_connector_extra_config": { "ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
Unit tests for ECSharedStorageConnector. Unit tests for ECExampleConnector.
""" """
import os import os
...@@ -13,9 +13,9 @@ import torch ...@@ -13,9 +13,9 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( from vllm.distributed.ec_transfer.ec_connector.example_connector import (
ECSharedStorageConnector, ECExampleConnector,
ECSharedStorageConnectorMetadata, ECExampleConnectorMetadata,
MMMeta, MMMeta,
) )
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
...@@ -81,12 +81,12 @@ def mock_request_with_3_mm(): ...@@ -81,12 +81,12 @@ def mock_request_with_3_mm():
# ------------------ Unit Tests ------------------ # # ------------------ Unit Tests ------------------ #
class TestECSharedStorageConnectorBasics: class TestECExampleConnectorBasics:
"""Test basic EC connector functionality.""" """Test basic EC connector functionality."""
def test_initialization_producer(self, mock_vllm_config_producer, temp_storage): def test_initialization_producer(self, mock_vllm_config_producer, temp_storage):
"""Test connector initializes correctly as producer.""" """Test connector initializes correctly as producer."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -98,7 +98,7 @@ class TestECSharedStorageConnectorBasics: ...@@ -98,7 +98,7 @@ class TestECSharedStorageConnectorBasics:
def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage): def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage):
"""Test connector initializes correctly as consumer.""" """Test connector initializes correctly as consumer."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -109,11 +109,11 @@ class TestECSharedStorageConnectorBasics: ...@@ -109,11 +109,11 @@ class TestECSharedStorageConnectorBasics:
def test_role_assignment(self, mock_vllm_config_producer): def test_role_assignment(self, mock_vllm_config_producer):
"""Test role is correctly assigned.""" """Test role is correctly assigned."""
scheduler_connector = ECSharedStorageConnector( scheduler_connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
worker_connector = ECSharedStorageConnector( worker_connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -133,7 +133,7 @@ class TestCacheExistence: ...@@ -133,7 +133,7 @@ class TestCacheExistence:
): ):
"""Test has_caches returns True when all 3 caches exist.""" """Test has_caches returns True when all 3 caches exist."""
# Test for producer first # Test for producer first
producer = ECSharedStorageConnector( producer = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -154,7 +154,7 @@ class TestCacheExistence: ...@@ -154,7 +154,7 @@ class TestCacheExistence:
assert all(producer_result), f"Expected all True, got {producer_result}" assert all(producer_result), f"Expected all True, got {producer_result}"
# Also test consumer can check if cache exists # Also test consumer can check if cache exists
consumer = ECSharedStorageConnector( consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -170,7 +170,7 @@ class TestCacheExistence: ...@@ -170,7 +170,7 @@ class TestCacheExistence:
self, mock_vllm_config_producer, mock_request_with_3_mm self, mock_vllm_config_producer, mock_request_with_3_mm
): ):
"""Test has_caches returns False when no caches exist.""" """Test has_caches returns False when no caches exist."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -186,7 +186,7 @@ class TestCacheExistence: ...@@ -186,7 +186,7 @@ class TestCacheExistence:
self, mock_vllm_config_producer, mock_request_with_3_mm self, mock_vllm_config_producer, mock_request_with_3_mm
): ):
"""Test has_caches with some caches existing (1 of 3).""" """Test has_caches with some caches existing (1 of 3)."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -213,7 +213,7 @@ class TestStateManagement: ...@@ -213,7 +213,7 @@ class TestStateManagement:
self, mock_vllm_config_producer, mock_request_with_3_mm self, mock_vllm_config_producer, mock_request_with_3_mm
): ):
"""Test state update after allocation for 3 MM items.""" """Test state update after allocation for 3 MM items."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -238,7 +238,7 @@ class TestStateManagement: ...@@ -238,7 +238,7 @@ class TestStateManagement:
self, mock_vllm_config_producer, mock_request_with_3_mm self, mock_vllm_config_producer, mock_request_with_3_mm
): ):
"""Test metadata building for 3 MM items.""" """Test metadata building for 3 MM items."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -252,7 +252,7 @@ class TestStateManagement: ...@@ -252,7 +252,7 @@ class TestStateManagement:
metadata = connector.build_connector_meta(scheduler_output) metadata = connector.build_connector_meta(scheduler_output)
# Assert # Assert
assert isinstance(metadata, ECSharedStorageConnectorMetadata) assert isinstance(metadata, ECExampleConnectorMetadata)
assert len(metadata.mm_datas) == 3 assert len(metadata.mm_datas) == 3
assert metadata.mm_datas[0].mm_hash == "img_hash_1" assert metadata.mm_datas[0].mm_hash == "img_hash_1"
assert metadata.mm_datas[0].num_token == 100 assert metadata.mm_datas[0].num_token == 100
...@@ -266,7 +266,7 @@ class TestStateManagement: ...@@ -266,7 +266,7 @@ class TestStateManagement:
def test_build_connector_meta_empty(self, mock_vllm_config_producer): def test_build_connector_meta_empty(self, mock_vllm_config_producer):
"""Test metadata building with empty state.""" """Test metadata building with empty state."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -274,14 +274,14 @@ class TestStateManagement: ...@@ -274,14 +274,14 @@ class TestStateManagement:
scheduler_output = Mock(spec=SchedulerOutput) scheduler_output = Mock(spec=SchedulerOutput)
metadata = connector.build_connector_meta(scheduler_output) metadata = connector.build_connector_meta(scheduler_output)
assert isinstance(metadata, ECSharedStorageConnectorMetadata) assert isinstance(metadata, ECExampleConnectorMetadata)
assert len(metadata.mm_datas) == 0 assert len(metadata.mm_datas) == 0
def test_state_cleared_after_metadata_build( def test_state_cleared_after_metadata_build(
self, mock_vllm_config_producer, mock_request_with_3_mm self, mock_vllm_config_producer, mock_request_with_3_mm
): ):
"""Test that state is properly cleared after building metadata.""" """Test that state is properly cleared after building metadata."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
...@@ -310,7 +310,7 @@ class TestCacheSaving: ...@@ -310,7 +310,7 @@ class TestCacheSaving:
self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage
): ):
"""Test cache saving as producer for 3 different MM items.""" """Test cache saving as producer for 3 different MM items."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -336,7 +336,7 @@ class TestCacheSaving: ...@@ -336,7 +336,7 @@ class TestCacheSaving:
def test_save_caches_consumer_skips(self, mock_vllm_config_consumer): def test_save_caches_consumer_skips(self, mock_vllm_config_consumer):
"""Test cache saving is skipped for consumer.""" """Test cache saving is skipped for consumer."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -366,7 +366,7 @@ class TestCacheLoading: ...@@ -366,7 +366,7 @@ class TestCacheLoading:
): ):
"""Test consumer loads 3 caches from storage.""" """Test consumer loads 3 caches from storage."""
# First, create producer to save caches # First, create producer to save caches
producer = ECSharedStorageConnector( producer = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -379,13 +379,13 @@ class TestCacheLoading: ...@@ -379,13 +379,13 @@ class TestCacheLoading:
producer.save_caches(saved_caches, mm_hash) producer.save_caches(saved_caches, mm_hash)
# Now consumer loads # Now consumer loads
consumer = ECSharedStorageConnector( consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
# Setup metadata for all 3 # Setup metadata for all 3
metadata = ECSharedStorageConnectorMetadata() metadata = ECExampleConnectorMetadata()
for mm_hash in mm_hashes: for mm_hash in mm_hashes:
metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100)) metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
consumer.bind_connector_metadata(metadata) consumer.bind_connector_metadata(metadata)
...@@ -410,7 +410,7 @@ class TestCacheLoading: ...@@ -410,7 +410,7 @@ class TestCacheLoading:
): ):
"""Test cache loading skips already cached items.""" """Test cache loading skips already cached items."""
# Setup: producer saves cache # Setup: producer saves cache
producer = ECSharedStorageConnector( producer = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -420,12 +420,12 @@ class TestCacheLoading: ...@@ -420,12 +420,12 @@ class TestCacheLoading:
producer.save_caches({mm_hash: saved_cache}, mm_hash) producer.save_caches({mm_hash: saved_cache}, mm_hash)
# Consumer setup # Consumer setup
consumer = ECSharedStorageConnector( consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
metadata = ECSharedStorageConnectorMetadata() metadata = ECExampleConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100)) metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
consumer.bind_connector_metadata(metadata) consumer.bind_connector_metadata(metadata)
...@@ -444,13 +444,13 @@ class TestCacheLoading: ...@@ -444,13 +444,13 @@ class TestCacheLoading:
def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer): def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer):
"""Test loading with empty metadata does nothing.""" """Test loading with empty metadata does nothing."""
consumer = ECSharedStorageConnector( consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
# Setup empty metadata # Setup empty metadata
metadata = ECSharedStorageConnectorMetadata() metadata = ECExampleConnectorMetadata()
consumer.bind_connector_metadata(metadata) consumer.bind_connector_metadata(metadata)
# Load (should not raise) # Load (should not raise)
...@@ -466,7 +466,7 @@ class TestFilenameGeneration: ...@@ -466,7 +466,7 @@ class TestFilenameGeneration:
def test_generate_foldername(self, mock_vllm_config_producer, temp_storage): def test_generate_foldername(self, mock_vllm_config_producer, temp_storage):
"""Test folder name generation.""" """Test folder name generation."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -479,7 +479,7 @@ class TestFilenameGeneration: ...@@ -479,7 +479,7 @@ class TestFilenameGeneration:
def test_generate_filename(self, mock_vllm_config_producer, temp_storage): def test_generate_filename(self, mock_vllm_config_producer, temp_storage):
"""Test filename generation.""" """Test filename generation."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -493,7 +493,7 @@ class TestFilenameGeneration: ...@@ -493,7 +493,7 @@ class TestFilenameGeneration:
def test_generate_filename_consistency(self, mock_vllm_config_producer): def test_generate_filename_consistency(self, mock_vllm_config_producer):
"""Test filename generation is consistent.""" """Test filename generation is consistent."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -510,12 +510,12 @@ class TestMetadataBindingLifecycle: ...@@ -510,12 +510,12 @@ class TestMetadataBindingLifecycle:
def test_bind_connector_metadata(self, mock_vllm_config_consumer): def test_bind_connector_metadata(self, mock_vllm_config_consumer):
"""Test binding connector metadata.""" """Test binding connector metadata."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
metadata = ECSharedStorageConnectorMetadata() metadata = ECExampleConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta("hash_1", 100)) metadata.add_mm_data(MMMeta.make_meta("hash_1", 100))
connector.bind_connector_metadata(metadata) connector.bind_connector_metadata(metadata)
...@@ -524,12 +524,12 @@ class TestMetadataBindingLifecycle: ...@@ -524,12 +524,12 @@ class TestMetadataBindingLifecycle:
def test_clear_connector_metadata(self, mock_vllm_config_consumer): def test_clear_connector_metadata(self, mock_vllm_config_consumer):
"""Test clearing connector metadata.""" """Test clearing connector metadata."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
metadata = ECSharedStorageConnectorMetadata() metadata = ECExampleConnectorMetadata()
connector.bind_connector_metadata(metadata) connector.bind_connector_metadata(metadata)
connector.clear_connector_metadata() connector.clear_connector_metadata()
...@@ -538,12 +538,12 @@ class TestMetadataBindingLifecycle: ...@@ -538,12 +538,12 @@ class TestMetadataBindingLifecycle:
def test_get_connector_metadata(self, mock_vllm_config_consumer): def test_get_connector_metadata(self, mock_vllm_config_consumer):
"""Test getting connector metadata.""" """Test getting connector metadata."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
metadata = ECSharedStorageConnectorMetadata() metadata = ECExampleConnectorMetadata()
connector.bind_connector_metadata(metadata) connector.bind_connector_metadata(metadata)
retrieved = connector._get_connector_metadata() retrieved = connector._get_connector_metadata()
...@@ -552,7 +552,7 @@ class TestMetadataBindingLifecycle: ...@@ -552,7 +552,7 @@ class TestMetadataBindingLifecycle:
def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer): def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer):
"""Test getting metadata when not set raises.""" """Test getting metadata when not set raises."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -566,7 +566,7 @@ class TestEdgeCases: ...@@ -566,7 +566,7 @@ class TestEdgeCases:
def test_save_empty_cache(self, mock_vllm_config_producer): def test_save_empty_cache(self, mock_vllm_config_producer):
"""Test saving empty tensor.""" """Test saving empty tensor."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
...@@ -579,12 +579,12 @@ class TestEdgeCases: ...@@ -579,12 +579,12 @@ class TestEdgeCases:
def test_load_nonexistent_cache(self, mock_vllm_config_consumer): def test_load_nonexistent_cache(self, mock_vllm_config_consumer):
"""Test loading cache that doesn't exist raises error.""" """Test loading cache that doesn't exist raises error."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer, vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER, role=ECConnectorRole.WORKER,
) )
metadata = ECSharedStorageConnectorMetadata() metadata = ECExampleConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100)) metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100))
connector.bind_connector_metadata(metadata) connector.bind_connector_metadata(metadata)
...@@ -596,7 +596,7 @@ class TestEdgeCases: ...@@ -596,7 +596,7 @@ class TestEdgeCases:
def test_has_caches_empty_request(self, mock_vllm_config_producer): def test_has_caches_empty_request(self, mock_vllm_config_producer):
"""Test has_caches with request that has no MM data.""" """Test has_caches with request that has no MM data."""
connector = ECSharedStorageConnector( connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
......
...@@ -507,7 +507,7 @@ def test_encoder_instance_zero_kv_cache( ...@@ -507,7 +507,7 @@ def test_encoder_instance_zero_kv_cache(
) )
kv_transfer_config = ( kv_transfer_config = (
KVTransferConfig( KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
) )
...@@ -515,7 +515,7 @@ def test_encoder_instance_zero_kv_cache( ...@@ -515,7 +515,7 @@ def test_encoder_instance_zero_kv_cache(
else None else None
) )
ec_transfer_config = ECTransferConfig( ec_transfer_config = ECTransferConfig(
ec_connector="ECSharedStorageConnector", ec_connector="ECExampleConnector",
ec_role=ec_role, ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"}, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"},
) )
......
...@@ -218,12 +218,12 @@ def test_internal_connector_uses_new_signature(): ...@@ -218,12 +218,12 @@ def test_internal_connector_uses_new_signature():
Test that internal connectors (registered in factory) always use the new Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config. signature and get kv_cache_config.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
SharedStorageConnector, ExampleConnector,
) )
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector" vllm_config.kv_transfer_config.kv_connector = "ExampleConnector"
scheduler = create_scheduler(vllm_config) scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config kv_cache_config = scheduler.kv_cache_config
...@@ -233,7 +233,7 @@ def test_internal_connector_uses_new_signature(): ...@@ -233,7 +233,7 @@ def test_internal_connector_uses_new_signature():
) )
assert connector is not None assert connector is not None
assert isinstance(connector, SharedStorageConnector) assert isinstance(connector, ExampleConnector)
assert connector._kv_cache_config is not None assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config assert connector._kv_cache_config == kv_cache_config
......
...@@ -119,16 +119,16 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]): ...@@ -119,16 +119,16 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
) )
def test_shared_storage_connector_hashes(tmp_path): def test_shared_storage_connector_hashes(tmp_path):
""" """
Tests that SharedStorageConnector saves KV to the storage locations Tests that ExampleConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but with proper hashes; that are unique for inputs with identical text but
different images (same size), or same multiple images but different orders. different images (same size), or same multiple images but different orders.
""" """
# Using tmp_path as the storage path to store KV # Using tmp_path as the storage path to store KV
print(f"KV storage path at: {str(tmp_path)}") print(f"KV storage path at: {str(tmp_path)}")
# Configure the SharedStorageConnector # Configure the ExampleConnector
kv_transfer_config = KVTransferConfig( kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": str(tmp_path)}, kv_connector_extra_config={"shared_storage_path": str(tmp_path)},
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment