Unverified Commit 33b3cb8a authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: embedding cache in agg vLLM node (#6153)

parent c8276cd2
......@@ -436,6 +436,32 @@ def setup_vllm_engine(config, stat_logger=None):
engine_args.create_model_config().get_diff_sampling_param()
)
# Configure ec_both mode with DynamoMultimodalEmbeddingCacheConnector.
# Must happen BEFORE engine setup so vLLM sees ec_transfer_config.
if (
not config.route_to_encoder
and config.multimodal_embedding_cache_capacity_gb > 0
):
from vllm.config import ECTransferConfig
logger.info(
"Configuring ec_both mode with DynamoMultimodalEmbeddingCacheConnector "
"(capacity=%.2f GB)",
config.multimodal_embedding_cache_capacity_gb,
)
instance_id = 0
engine_id = f"{config.namespace}.{config.component}.backend.{instance_id}"
engine_args.ec_transfer_config = ECTransferConfig(
engine_id=engine_id,
ec_role="ec_both",
ec_connector="DynamoMultimodalEmbeddingCacheConnector",
ec_connector_module_path="dynamo.vllm.multimodal_utils.multimodal_embedding_cache_connector",
ec_connector_extra_config={
"multimodal_embedding_cache_capacity_gb": config.multimodal_embedding_cache_capacity_gb,
},
)
logger.info("Configured ec_both with engine_id=%s", engine_id)
# Taken from build_async_engine_client_from_engine_args()
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from packaging.version import Version
from vllm import __version__ as _vllm_version
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase,
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.request import Request
MINIMUM_VLLM_VERSION = "0.17.0"
logger = logging.getLogger(__name__)
@dataclass
class MultimodalEmbeddingCacheConnectorMetadata(ECConnectorMetadata):
"""Commands from scheduler to worker for CPU embedding cache management."""
loads: list[str] = field(default_factory=list)
saves: list[str] = field(default_factory=list)
evicts: list[str] = field(default_factory=list)
class DynamoMultimodalEmbeddingCacheConnector(ECConnectorBase):
"""EC connector with scheduler-authoritative CPU embedding cache.
The scheduler maintains a logical LRU cache (OrderedDict) and issues
load/save/evict commands to the worker via ECConnectorMetadata. The
worker holds a plain dict[str, Tensor] on CPU and obeys commands
without independent caching decisions.
This mirrors vLLM's EncoderCacheManager pattern: the scheduler is the
single source of truth for cache state; the worker is a plain dict storage.
"""
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole) -> None:
if Version(_vllm_version) < Version(MINIMUM_VLLM_VERSION):
logger.warning(
"DynamoMultimodalEmbeddingCacheConnector requires vLLM >= %s, "
"but found %s. Some features may not work correctly.",
MINIMUM_VLLM_VERSION,
_vllm_version,
)
super().__init__(vllm_config=vllm_config, role=role)
transfer_config = vllm_config.ec_transfer_config
if transfer_config is None:
raise ValueError(
"ec_transfer_config must be set for DynamoMultimodalEmbeddingCacheConnector"
)
if "multimodal_embedding_cache_capacity_gb" not in (
transfer_config.ec_connector_extra_config or {}
):
raise ValueError(
"multimodal_embedding_cache_capacity_gb must be set in "
"ec_connector_extra_config for DynamoMultimodalEmbeddingCacheConnector"
)
capacity_gb: float = transfer_config.ec_connector_extra_config[
"multimodal_embedding_cache_capacity_gb"
]
# --- Scheduler-side: logical LRU for CPU embedding cache ---
# Mirrors EncoderCacheManager but for the CPU tier, tracking bytes.
hidden_size = vllm_config.model_config.get_hidden_size()
dtype_bytes = torch.tensor(
[], dtype=vllm_config.model_config.dtype
).element_size()
self._bytes_per_embed = hidden_size * dtype_bytes
self._capacity_bytes = int(capacity_gb * 1024**3)
self._cache_order: OrderedDict[str, int] = OrderedDict() # hash → size_bytes
self._num_used_bytes: int = 0
self._loads_this_step: set[str] = set()
self._saves_this_step: set[str] = set()
self._evicts_this_step: set[str] = set()
# --- Worker-side: dumb CPU tensor store ---
self._cpu_store: dict[str, torch.Tensor] = {}
logger.info(
"DynamoMultimodalEmbeddingCacheConnector initialized: "
"capacity_gb=%.2f, capacity_bytes=%d, bytes_per_embed=%d",
capacity_gb,
self._capacity_bytes,
self._bytes_per_embed,
)
# ==============================
# Scheduler-side methods
#
# vLLM scheduler call sequence per multimodal feature:
#
# 1. encoder_cache_manager.check_and_update_cache(request, i)
# → if True (GPU hit): skip entirely, neither method below is called.
#
# 2. has_cache_item(identifier)
# → if True (CPU hit): item goes to external_load_encoder_input
# → if False (CPU miss): item goes to encoder_inputs_to_schedule
#
# 3. update_state_after_alloc(request, i) is called for both paths.
# The two paths are mutually exclusive per hash within a step:
# - external_load_encoder_input → mm_hash IN _cache_order → load path
# - encoder_inputs_to_schedule → mm_hash NOT in _cache_order → save path
# ==============================
def has_cache_item(self, identifier: str) -> bool:
"""Check if an embedding is in the CPU cache, promoting it to MRU on hit.
Called by the scheduler only after the GPU encoder_cache_manager reports
a miss. A True return tells the scheduler to skip encoder compute and
load the embedding from the CPU store instead.
"""
if identifier in self._cache_order:
self._cache_order.move_to_end(identifier)
return True
return False
def update_state_after_alloc(self, request: "Request", index: int) -> None:
"""Record a load or save command for a multimodal feature.
Called by the scheduler after has_cache_item has already determined
the path. The _cache_order check here mirrors that decision:
CPU hit (mm_hash in _cache_order): mark for CPU→GPU load.
CPU miss (mm_hash not in _cache_order): evict LRU entries if needed,
then mark for GPU→CPU save so the worker persists the newly
computed embedding. Silently skips items larger than total capacity.
"""
mm_hash: str = request.mm_features[index].identifier
num_embeds: int = request.get_num_encoder_embeds(index)
size_bytes: int = num_embeds * self._bytes_per_embed
if mm_hash in self._cache_order:
self._cache_order.move_to_end(mm_hash)
self._loads_this_step.add(mm_hash)
return
if size_bytes > self._capacity_bytes:
return
self._saves_this_step.add(mm_hash)
while (
self._num_used_bytes + size_bytes > self._capacity_bytes
and self._cache_order
):
evicted_hash, evicted_bytes = self._cache_order.popitem(last=False)
self._num_used_bytes -= evicted_bytes
self._evicts_this_step.add(evicted_hash)
self._cache_order[mm_hash] = size_bytes
self._num_used_bytes += size_bytes
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> ECConnectorMetadata:
"""Flush accumulated load/save/evict commands into metadata for the worker."""
meta = MultimodalEmbeddingCacheConnectorMetadata(
loads=list(self._loads_this_step),
saves=list(self._saves_this_step),
evicts=list(self._evicts_this_step),
)
self._loads_this_step.clear()
self._saves_this_step.clear()
self._evicts_this_step.clear()
return meta
# ==============================
# Worker-side methods
#
# Called by the model runner each step with the metadata produced by
# build_connector_meta. The worker has no caching logic of its own;
# it simply obeys the scheduler's load/save/evict commands.
# ==============================
def start_load_caches(
self, encoder_cache: dict[str, torch.Tensor], **kwargs
) -> None:
"""Copy cached embeddings from CPU store to GPU encoder_cache, and evict
entries the scheduler marked for removal.
"""
metadata = self._get_connector_metadata()
assert isinstance(metadata, MultimodalEmbeddingCacheConnectorMetadata)
for mm_hash in metadata.loads:
if mm_hash in encoder_cache:
continue
if mm_hash in self._cpu_store:
encoder_cache[mm_hash] = self._cpu_store[mm_hash].to(
"cuda", non_blocking=True
)
else:
logger.warning(
"start_load_caches: hash %s not in cpu_store, skipping", mm_hash
)
for mm_hash in metadata.evicts:
self._cpu_store.pop(mm_hash, None)
def save_caches(
self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
) -> None:
"""Copy a newly computed embedding from GPU encoder_cache to CPU store."""
metadata = self._get_connector_metadata()
assert isinstance(metadata, MultimodalEmbeddingCacheConnectorMetadata)
if mm_hash not in metadata.saves:
return
if mm_hash in self._cpu_store:
return
if mm_hash not in encoder_cache:
logger.warning(
"save_caches: hash %s in metadata.saves but not in encoder_cache",
mm_hash,
)
return
self._cpu_store[mm_hash] = encoder_cache[mm_hash].cpu()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for DynamoMultimodalEmbeddingCacheConnector."""
from unittest.mock import MagicMock, patch
import pytest
import torch
from dynamo.vllm.multimodal_utils import multimodal_embedding_cache_connector as mod
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.multimodal,
]
def _make_vllm_config(capacity_gb: float = 1.0) -> MagicMock:
config = MagicMock()
config.ec_transfer_config.ec_connector_extra_config = {
"multimodal_embedding_cache_capacity_gb": capacity_gb,
}
config.model_config.get_hidden_size.return_value = 4096
config.model_config.dtype = torch.float16
return config
class TestVersionCheck:
def test_warns_old_vllm(self):
with (
patch.object(mod, "_vllm_version", "0.16.5"),
patch.object(mod.ECConnectorBase, "__init__", return_value=None),
patch.object(mod.logger, "warning") as mock_warn,
):
connector = mod.DynamoMultimodalEmbeddingCacheConnector(
vllm_config=_make_vllm_config(),
role=MagicMock(),
)
assert connector is not None
mock_warn.assert_called_once()
assert mock_warn.call_args[0][1] == mod.MINIMUM_VLLM_VERSION
assert mock_warn.call_args[0][2] == "0.16.5"
class TestSchedulerSideLRU:
"""Test the scheduler-side logical LRU cache and metadata generation."""
def _make_connector(self, capacity_gb: float = 1.0):
with patch.object(mod.ECConnectorBase, "__init__", return_value=None):
return mod.DynamoMultimodalEmbeddingCacheConnector(
vllm_config=_make_vllm_config(capacity_gb),
role=MagicMock(),
)
def _make_request(self, hashes_and_embeds: list[tuple[str, int]]) -> MagicMock:
request = MagicMock()
features = []
for h, _ in hashes_and_embeds:
f = MagicMock()
f.identifier = h
features.append(f)
request.mm_features = features
def get_num_encoder_embeds(idx):
return hashes_and_embeds[idx][1]
request.get_num_encoder_embeds = get_num_encoder_embeds
return request
def test_has_cache_item_miss_then_hit(self):
conn = self._make_connector()
assert not conn.has_cache_item("hash_a")
request = self._make_request([("hash_a", 100)])
conn.update_state_after_alloc(request, 0)
assert conn.has_cache_item("hash_a")
def test_update_state_plans_save(self):
conn = self._make_connector()
request = self._make_request([("hash_a", 100)])
conn.update_state_after_alloc(request, 0)
scheduler_output = MagicMock()
meta = conn.build_connector_meta(scheduler_output)
assert isinstance(meta, mod.MultimodalEmbeddingCacheConnectorMetadata)
assert "hash_a" in meta.saves
assert meta.loads == []
assert meta.evicts == []
def test_update_state_plans_load_for_cached(self):
conn = self._make_connector()
request = self._make_request([("hash_a", 100)])
conn.update_state_after_alloc(request, 0)
conn.build_connector_meta(MagicMock())
conn.update_state_after_alloc(request, 0)
meta = conn.build_connector_meta(MagicMock())
assert "hash_a" in meta.loads
assert meta.saves == []
def test_eviction_under_pressure(self):
# 4096 hidden_size * 2 bytes (fp16) = 8192 bytes per embed
conn = self._make_connector()
bpe = conn._bytes_per_embed # 8192
# Set capacity to hold exactly 200 embeds worth of bytes
conn._capacity_bytes = 200 * bpe
req_a = self._make_request([("hash_a", 100)])
conn.update_state_after_alloc(req_a, 0)
conn.build_connector_meta(MagicMock())
req_b = self._make_request([("hash_b", 100)])
conn.update_state_after_alloc(req_b, 0)
conn.build_connector_meta(MagicMock())
assert conn._num_used_bytes == 200 * bpe
# Adding hash_c (100 embeds) should evict hash_a (LRU)
req_c = self._make_request([("hash_c", 100)])
conn.update_state_after_alloc(req_c, 0)
meta = conn.build_connector_meta(MagicMock())
assert "hash_c" in meta.saves
assert "hash_a" in meta.evicts
assert "hash_a" not in conn._cache_order
assert "hash_c" in conn._cache_order
def test_skip_oversized_item(self):
conn = self._make_connector()
bpe = conn._bytes_per_embed
conn._capacity_bytes = 50 * bpe
request = self._make_request([("huge_hash", 100)])
conn.update_state_after_alloc(request, 0)
meta = conn.build_connector_meta(MagicMock())
assert meta.saves == []
assert meta.loads == []
assert "huge_hash" not in conn._cache_order
......@@ -10,6 +10,13 @@ import pytest
from dynamo.vllm.worker_factory import EngineSetupResult, WorkerFactory
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
]
def _make_config(**overrides) -> Mock:
"""Create a mock Config with all multimodal flags defaulting to False."""
......@@ -71,7 +78,7 @@ class TestCreate:
config = _make_config(multimodal_encode_worker=True)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event)
await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_encode_worker.assert_called_once() # type: ignore[union-attr]
......@@ -80,7 +87,7 @@ class TestCreate:
config = _make_config(multimodal_worker=True)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event)
await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr]
......@@ -91,7 +98,7 @@ class TestCreate:
config = _make_config(multimodal_decode_worker=True)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event)
await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr]
......@@ -100,6 +107,7 @@ class TestCreate:
config = _make_config(multimodal_worker=True)
runtime = Mock()
shutdown_event = asyncio.Event()
shutdown_endpoints: list = []
pre_created_engine: EngineSetupResult = (
Mock(),
Mock(),
......@@ -109,15 +117,23 @@ class TestCreate:
)
await factory.create(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
runtime,
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
)
factory._create_multimodal_worker.assert_called_once_with( # type: ignore[union-attr]
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
runtime,
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
)
@pytest.mark.asyncio
async def test_raises_when_no_multimodal_flag(self, factory: WorkerFactory) -> None:
config = _make_config()
with pytest.raises(ValueError, match="no multimodal worker type set"):
await factory.create(Mock(), config, asyncio.Event())
await factory.create(Mock(), config, asyncio.Event(), [])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment