Unverified Commit df8fd92b authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

chore: consistent name -- MultimodalEmbeddingCache (#5962)

parent fb62e2cf
......@@ -3,6 +3,8 @@
"""Memory management utilities for Dynamo components."""
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
__all__ = ["EncoderCacheManager"]
__all__ = ["MultimodalEmbeddingCacheManager"]
......@@ -8,7 +8,7 @@ A simple LRU cache for encoder embeddings (tensors).
Maps content hash keys to tensors with capacity-based eviction.
Usage:
cache = EncoderCacheManager(capacity_bytes=4 * 1024**3) # 4GB
cache = MultimodalEmbeddingCacheManager(capacity_bytes=4 * 1024**3) # 4GB
# Store embedding
cache.set("abc123", embedding_tensor)
......@@ -26,7 +26,7 @@ import torch
logger = logging.getLogger(__name__)
class EncoderCacheManager:
class MultimodalEmbeddingCacheManager:
"""
LRU cache for encoder embeddings.
......@@ -56,7 +56,7 @@ class EncoderCacheManager:
self._misses = 0
logger.info(
f"EncoderCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB"
f"MultimodalEmbeddingCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB"
)
@staticmethod
......
......@@ -4,11 +4,11 @@
"""
Async Encoder Cache
Async wrapper over EncoderCacheManager with request coalescing.
Async wrapper over MultimodalEmbeddingCacheManager with request coalescing.
Prevents duplicate encoding when multiple requests arrive for the same content.
Usage:
cache = EncoderCacheManager(capacity_bytes=4 * 1024**3)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=4 * 1024**3)
async_cache = AsyncEncoderCache(cache)
# Get from cache or compute with coalescing
......@@ -21,7 +21,9 @@ from typing import Awaitable, Callable, Dict, Optional
import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
logger = logging.getLogger(__name__)
......@@ -43,7 +45,7 @@ def _suppress_unhandled_future_exception(future: asyncio.Future) -> None:
class AsyncEncoderCache:
"""
Async wrapper with request coalescing over EncoderCacheManager.
Async wrapper with request coalescing over MultimodalEmbeddingCacheManager.
Provides async get_or_compute that deduplicates concurrent requests
for the same key, ensuring only one encoding runs at a time per key.
......@@ -53,12 +55,12 @@ class AsyncEncoderCache:
asyncio event loop. All access must be from the same thread.
"""
def __init__(self, cache: EncoderCacheManager):
def __init__(self, cache: MultimodalEmbeddingCacheManager):
"""
Initialize the async encoder cache.
Args:
cache: Underlying EncoderCacheManager for storage.
cache: Underlying MultimodalEmbeddingCacheManager for storage.
"""
self._cache = cache
self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {}
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for EncoderCacheManager."""
"""Unit tests for MultimodalEmbeddingCacheManager."""
import pytest
import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
class TestEncoderCacheManagerBasicOperations:
class TestMultimodalEmbeddingCacheManagerBasicOperations:
"""Tests for basic get/set operations."""
def test_set_and_get(self):
"""Test basic set and get operations."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) # 1MB
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024) # 1MB
tensor = torch.randn(100, 100) # ~40KB for float32
result = cache.set("key1", tensor)
......@@ -26,14 +28,14 @@ class TestEncoderCacheManagerBasicOperations:
def test_get_nonexistent_key(self):
"""Test get returns None for nonexistent key."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
result = cache.get("nonexistent")
assert result is None
def test_set_overwrites_existing_key(self):
"""Test set overwrites existing key."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor1 = torch.randn(10, 10)
tensor2 = torch.randn(10, 10)
......@@ -45,7 +47,7 @@ class TestEncoderCacheManagerBasicOperations:
assert cache.stats["entries"] == 1
class TestEncoderCacheManagerLRUEviction:
class TestMultimodalEmbeddingCacheManagerLRUEviction:
"""Tests for LRU eviction behavior."""
def test_eviction_when_full(self):
......@@ -53,7 +55,7 @@ class TestEncoderCacheManagerLRUEviction:
# Small capacity to force eviction
tensor_size = 10 * 10 * 4 # 400 bytes for float32
capacity = tensor_size * 2 + 100 # Room for ~2 tensors
cache = EncoderCacheManager(capacity_bytes=capacity)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)
......@@ -73,7 +75,7 @@ class TestEncoderCacheManagerLRUEviction:
"""Test that get() updates LRU order."""
tensor_size = 10 * 10 * 4 # 400 bytes
capacity = tensor_size * 2 + 100 # Room for ~2 tensors
cache = EncoderCacheManager(capacity_bytes=capacity)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)
......@@ -94,7 +96,7 @@ class TestEncoderCacheManagerLRUEviction:
def test_tensor_too_large_for_cache(self):
"""Test that tensor larger than capacity is not cached."""
cache = EncoderCacheManager(capacity_bytes=100) # Very small
cache = MultimodalEmbeddingCacheManager(capacity_bytes=100) # Very small
tensor = torch.randn(100, 100) # ~40KB, way larger than capacity
result = cache.set("key1", tensor)
......@@ -104,12 +106,12 @@ class TestEncoderCacheManagerLRUEviction:
assert cache.stats["entries"] == 0
class TestEncoderCacheManagerSizeTracking:
class TestMultimodalEmbeddingCacheManagerSizeTracking:
"""Tests for memory size tracking."""
def test_current_bytes_tracking(self):
"""Test that current_bytes is tracked correctly."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
t1 = torch.randn(10, 10) # 400 bytes
t2 = torch.randn(20, 20) # 1600 bytes
......@@ -125,7 +127,7 @@ class TestEncoderCacheManagerSizeTracking:
def test_size_updated_on_overwrite(self):
"""Test that size is updated correctly when overwriting."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
small_tensor = torch.randn(10, 10) # 400 bytes
large_tensor = torch.randn(20, 20) # 1600 bytes
......@@ -140,12 +142,12 @@ class TestEncoderCacheManagerSizeTracking:
assert cache.stats["current_bytes"] > initial_size
class TestEncoderCacheManagerStats:
class TestMultimodalEmbeddingCacheManagerStats:
"""Tests for statistics tracking."""
def test_hit_miss_tracking(self):
"""Test hit and miss counting."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10)
cache.set("key1", tensor)
......@@ -166,7 +168,7 @@ class TestEncoderCacheManagerStats:
def test_stats_content(self):
"""Test stats dictionary contains expected keys."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10)
cache.set("key1", tensor)
......@@ -186,7 +188,7 @@ class TestEncoderCacheManagerStats:
def test_utilization_calculation(self):
"""Test utilization is calculated correctly."""
capacity = 1000
cache = EncoderCacheManager(capacity_bytes=capacity)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
# Create tensor of known size
# float32 = 4 bytes, so 25 elements = 100 bytes
......@@ -198,12 +200,12 @@ class TestEncoderCacheManagerStats:
assert abs(stats["utilization"] - expected_utilization) < 0.001
class TestEncoderCacheManagerContiguousTensor:
class TestMultimodalEmbeddingCacheManagerContiguousTensor:
"""Tests for contiguous tensor requirement."""
def test_set_contiguous_tensor_succeeds(self):
"""Test that contiguous tensors can be cached."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10)
assert tensor.is_contiguous()
......@@ -212,7 +214,7 @@ class TestEncoderCacheManagerContiguousTensor:
def test_set_non_contiguous_tensor_raises(self):
"""Test that non-contiguous tensors raise AssertionError."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024)
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
# Create a non-contiguous tensor via transpose
tensor = torch.randn(10, 20).t()
......
......@@ -8,7 +8,9 @@ import asyncio
import pytest
import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
......@@ -18,7 +20,7 @@ class TestAsyncEncoderCacheBasicOperations:
@pytest.fixture
def cache(self):
"""Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024)
ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm)
def test_sync_get_returns_none_for_missing_key(self, cache):
......@@ -74,7 +76,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
@pytest.fixture
def cache(self):
"""Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024)
ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm)
@pytest.mark.asyncio
......@@ -137,7 +139,7 @@ class TestAsyncEncoderCacheExceptionHandling:
@pytest.fixture
def cache(self):
"""Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024)
ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm)
@pytest.mark.asyncio
......@@ -209,7 +211,7 @@ class TestAsyncEncoderCacheStats:
@pytest.fixture
def cache(self):
"""Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024)
ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm)
def test_stats_includes_in_flight(self, cache):
......
......@@ -14,7 +14,9 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.multimodal.async_encoder_cache import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles
from dynamo.trtllm.multimodal.hasher import MultimodalHasher
......@@ -25,7 +27,7 @@ async def fetch_embeddings_from_encoder(
image_urls: List[str],
request: Dict[str, Any],
encode_client: Any,
encoder_cache: Optional[EncoderCacheManager] = None,
encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
) -> Union[List[torch.Tensor], DisaggregatedParams]:
"""
Fetch embeddings from remote encode worker.
......@@ -112,7 +114,7 @@ async def _remote_encode_full_epd(
async def _fetch_embeddings_with_cache(
image_urls: List[str],
request: Dict[str, Any],
cache: EncoderCacheManager,
cache: MultimodalEmbeddingCacheManager,
encode_fn: Callable[[Dict[str, Any]], DisaggregatedParams],
) -> List[torch.Tensor]:
"""
......
......@@ -7,7 +7,9 @@ import logging
from typing import Optional
from dynamo._core import Context
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.request_handlers.handler_base import (
HandlerBase,
......@@ -26,7 +28,7 @@ class AggregatedHandler(HandlerBase):
def __init__(
self,
config: RequestHandlerConfig,
encoder_cache: Optional[EncoderCacheManager] = None,
encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
):
super().__init__(config)
self._encoder_cache = encoder_cache
......
......@@ -5,7 +5,9 @@ import logging
from typing import Optional
from dynamo._core import Context
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
......@@ -35,7 +37,7 @@ class RequestHandlerFactory:
encoder_cache = None
if config.encoder_cache_capacity_gb > 0:
capacity_bytes = int(config.encoder_cache_capacity_gb * 1024**3)
encoder_cache = EncoderCacheManager(capacity_bytes)
encoder_cache = MultimodalEmbeddingCacheManager(capacity_bytes)
if config.disaggregation_mode.value == "prefill":
return PrefillHandler(config, encoder_cache=encoder_cache)
if config.disaggregation_mode.value == "prefill_and_decode":
......@@ -90,7 +92,7 @@ class PrefillHandler(HandlerBase):
def __init__(
self,
config: RequestHandlerConfig,
encoder_cache: Optional[EncoderCacheManager] = None,
encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
):
super().__init__(config)
self._encoder_cache = encoder_cache
......
......@@ -10,7 +10,9 @@ import pytest
import torch
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.multimodal.hasher import MultimodalHasher
......@@ -53,9 +55,9 @@ def create_mock_encode_client(
@pytest.fixture
def encoder_cache() -> EncoderCacheManager:
def encoder_cache() -> MultimodalEmbeddingCacheManager:
"""Create encoder cache with 10MB capacity."""
return EncoderCacheManager(capacity_bytes=10 * 1024 * 1024)
return MultimodalEmbeddingCacheManager(capacity_bytes=10 * 1024 * 1024)
class TestFetchEmbeddingsFromEncoder:
......
......@@ -5,7 +5,9 @@
import pytest
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.request_handlers.handlers import (
AggregatedHandler,
PrefillHandler,
......@@ -53,7 +55,7 @@ class TestRequestHandlerFactory:
assert isinstance(handler, PrefillHandler)
def test_prefill_handler_with_encoder_cache(self):
"""Test factory creates PrefillHandler with EncoderCacheManager when capacity > 0."""
"""Test factory creates PrefillHandler with MultimodalEmbeddingCacheManager when capacity > 0."""
mock_config = create_mock_request_handler_config(
disaggregation_mode="prefill",
encoder_cache_capacity_gb=1.0,
......@@ -62,7 +64,7 @@ class TestRequestHandlerFactory:
handler = factory.get_request_handler(mock_config)
assert isinstance(handler, PrefillHandler)
assert isinstance(handler._encoder_cache, EncoderCacheManager)
assert isinstance(handler._encoder_cache, MultimodalEmbeddingCacheManager)
def test_prefill_handler_without_encoder_cache(self):
"""Test factory creates PrefillHandler with no cache when capacity is 0."""
......
......@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
def create_mock_encoder_cache() -> MagicMock:
"""Create mock EncoderCacheManager."""
"""Create mock MultimodalEmbeddingCacheManager."""
cache = MagicMock()
cache.get = MagicMock(return_value=None)
cache.set = MagicMock(return_value=True)
......
......@@ -53,6 +53,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.trtllm \
--modality "$MODALITY" \
--custom-jinja-template "$CUSTOM_TEMPLATE" \
--encode-endpoint "$ENCODE_ENDPOINT" \
--disaggregation-mode prefill_and_decode \
--dyn-encoder-cache-capacity-gb "$DYN_ENCODER_CACHE_CAPACITY_GB" &
PD_PID_1=$!
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment