Unverified Commit df8fd92b authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

chore: consistent name -- MultimodalEmbeddingCache (#5962)

parent fb62e2cf
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
"""Memory management utilities for Dynamo components.""" """Memory management utilities for Dynamo components."""
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
__all__ = ["EncoderCacheManager"] __all__ = ["MultimodalEmbeddingCacheManager"]
...@@ -8,7 +8,7 @@ A simple LRU cache for encoder embeddings (tensors). ...@@ -8,7 +8,7 @@ A simple LRU cache for encoder embeddings (tensors).
Maps content hash keys to tensors with capacity-based eviction. Maps content hash keys to tensors with capacity-based eviction.
Usage: Usage:
cache = EncoderCacheManager(capacity_bytes=4 * 1024**3) # 4GB cache = MultimodalEmbeddingCacheManager(capacity_bytes=4 * 1024**3) # 4GB
# Store embedding # Store embedding
cache.set("abc123", embedding_tensor) cache.set("abc123", embedding_tensor)
...@@ -26,7 +26,7 @@ import torch ...@@ -26,7 +26,7 @@ import torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EncoderCacheManager: class MultimodalEmbeddingCacheManager:
""" """
LRU cache for encoder embeddings. LRU cache for encoder embeddings.
...@@ -56,7 +56,7 @@ class EncoderCacheManager: ...@@ -56,7 +56,7 @@ class EncoderCacheManager:
self._misses = 0 self._misses = 0
logger.info( logger.info(
f"EncoderCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB" f"MultimodalEmbeddingCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB"
) )
@staticmethod @staticmethod
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
""" """
Async Encoder Cache Async Encoder Cache
Async wrapper over EncoderCacheManager with request coalescing. Async wrapper over MultimodalEmbeddingCacheManager with request coalescing.
Prevents duplicate encoding when multiple requests arrive for the same content. Prevents duplicate encoding when multiple requests arrive for the same content.
Usage: Usage:
cache = EncoderCacheManager(capacity_bytes=4 * 1024**3) cache = MultimodalEmbeddingCacheManager(capacity_bytes=4 * 1024**3)
async_cache = AsyncEncoderCache(cache) async_cache = AsyncEncoderCache(cache)
# Get from cache or compute with coalescing # Get from cache or compute with coalescing
...@@ -21,7 +21,9 @@ from typing import Awaitable, Callable, Dict, Optional ...@@ -21,7 +21,9 @@ from typing import Awaitable, Callable, Dict, Optional
import torch import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -43,7 +45,7 @@ def _suppress_unhandled_future_exception(future: asyncio.Future) -> None: ...@@ -43,7 +45,7 @@ def _suppress_unhandled_future_exception(future: asyncio.Future) -> None:
class AsyncEncoderCache: class AsyncEncoderCache:
""" """
Async wrapper with request coalescing over EncoderCacheManager. Async wrapper with request coalescing over MultimodalEmbeddingCacheManager.
Provides async get_or_compute that deduplicates concurrent requests Provides async get_or_compute that deduplicates concurrent requests
for the same key, ensuring only one encoding runs at a time per key. for the same key, ensuring only one encoding runs at a time per key.
...@@ -53,12 +55,12 @@ class AsyncEncoderCache: ...@@ -53,12 +55,12 @@ class AsyncEncoderCache:
asyncio event loop. All access must be from the same thread. asyncio event loop. All access must be from the same thread.
""" """
def __init__(self, cache: EncoderCacheManager): def __init__(self, cache: MultimodalEmbeddingCacheManager):
""" """
Initialize the async encoder cache. Initialize the async encoder cache.
Args: Args:
cache: Underlying EncoderCacheManager for storage. cache: Underlying MultimodalEmbeddingCacheManager for storage.
""" """
self._cache = cache self._cache = cache
self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {} self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {}
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Unit tests for EncoderCacheManager.""" """Unit tests for MultimodalEmbeddingCacheManager."""
import pytest import pytest
import torch import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
class TestEncoderCacheManagerBasicOperations: class TestMultimodalEmbeddingCacheManagerBasicOperations:
"""Tests for basic get/set operations.""" """Tests for basic get/set operations."""
def test_set_and_get(self): def test_set_and_get(self):
"""Test basic set and get operations.""" """Test basic set and get operations."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) # 1MB cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024) # 1MB
tensor = torch.randn(100, 100) # ~40KB for float32 tensor = torch.randn(100, 100) # ~40KB for float32
result = cache.set("key1", tensor) result = cache.set("key1", tensor)
...@@ -26,14 +28,14 @@ class TestEncoderCacheManagerBasicOperations: ...@@ -26,14 +28,14 @@ class TestEncoderCacheManagerBasicOperations:
def test_get_nonexistent_key(self): def test_get_nonexistent_key(self):
"""Test get returns None for nonexistent key.""" """Test get returns None for nonexistent key."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
result = cache.get("nonexistent") result = cache.get("nonexistent")
assert result is None assert result is None
def test_set_overwrites_existing_key(self): def test_set_overwrites_existing_key(self):
"""Test set overwrites existing key.""" """Test set overwrites existing key."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor1 = torch.randn(10, 10) tensor1 = torch.randn(10, 10)
tensor2 = torch.randn(10, 10) tensor2 = torch.randn(10, 10)
...@@ -45,7 +47,7 @@ class TestEncoderCacheManagerBasicOperations: ...@@ -45,7 +47,7 @@ class TestEncoderCacheManagerBasicOperations:
assert cache.stats["entries"] == 1 assert cache.stats["entries"] == 1
class TestEncoderCacheManagerLRUEviction: class TestMultimodalEmbeddingCacheManagerLRUEviction:
"""Tests for LRU eviction behavior.""" """Tests for LRU eviction behavior."""
def test_eviction_when_full(self): def test_eviction_when_full(self):
...@@ -53,7 +55,7 @@ class TestEncoderCacheManagerLRUEviction: ...@@ -53,7 +55,7 @@ class TestEncoderCacheManagerLRUEviction:
# Small capacity to force eviction # Small capacity to force eviction
tensor_size = 10 * 10 * 4 # 400 bytes for float32 tensor_size = 10 * 10 * 4 # 400 bytes for float32
capacity = tensor_size * 2 + 100 # Room for ~2 tensors capacity = tensor_size * 2 + 100 # Room for ~2 tensors
cache = EncoderCacheManager(capacity_bytes=capacity) cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
t1 = torch.randn(10, 10) t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10) t2 = torch.randn(10, 10)
...@@ -73,7 +75,7 @@ class TestEncoderCacheManagerLRUEviction: ...@@ -73,7 +75,7 @@ class TestEncoderCacheManagerLRUEviction:
"""Test that get() updates LRU order.""" """Test that get() updates LRU order."""
tensor_size = 10 * 10 * 4 # 400 bytes tensor_size = 10 * 10 * 4 # 400 bytes
capacity = tensor_size * 2 + 100 # Room for ~2 tensors capacity = tensor_size * 2 + 100 # Room for ~2 tensors
cache = EncoderCacheManager(capacity_bytes=capacity) cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
t1 = torch.randn(10, 10) t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10) t2 = torch.randn(10, 10)
...@@ -94,7 +96,7 @@ class TestEncoderCacheManagerLRUEviction: ...@@ -94,7 +96,7 @@ class TestEncoderCacheManagerLRUEviction:
def test_tensor_too_large_for_cache(self): def test_tensor_too_large_for_cache(self):
"""Test that tensor larger than capacity is not cached.""" """Test that tensor larger than capacity is not cached."""
cache = EncoderCacheManager(capacity_bytes=100) # Very small cache = MultimodalEmbeddingCacheManager(capacity_bytes=100) # Very small
tensor = torch.randn(100, 100) # ~40KB, way larger than capacity tensor = torch.randn(100, 100) # ~40KB, way larger than capacity
result = cache.set("key1", tensor) result = cache.set("key1", tensor)
...@@ -104,12 +106,12 @@ class TestEncoderCacheManagerLRUEviction: ...@@ -104,12 +106,12 @@ class TestEncoderCacheManagerLRUEviction:
assert cache.stats["entries"] == 0 assert cache.stats["entries"] == 0
class TestEncoderCacheManagerSizeTracking: class TestMultimodalEmbeddingCacheManagerSizeTracking:
"""Tests for memory size tracking.""" """Tests for memory size tracking."""
def test_current_bytes_tracking(self): def test_current_bytes_tracking(self):
"""Test that current_bytes is tracked correctly.""" """Test that current_bytes is tracked correctly."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
t1 = torch.randn(10, 10) # 400 bytes t1 = torch.randn(10, 10) # 400 bytes
t2 = torch.randn(20, 20) # 1600 bytes t2 = torch.randn(20, 20) # 1600 bytes
...@@ -125,7 +127,7 @@ class TestEncoderCacheManagerSizeTracking: ...@@ -125,7 +127,7 @@ class TestEncoderCacheManagerSizeTracking:
def test_size_updated_on_overwrite(self): def test_size_updated_on_overwrite(self):
"""Test that size is updated correctly when overwriting.""" """Test that size is updated correctly when overwriting."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
small_tensor = torch.randn(10, 10) # 400 bytes small_tensor = torch.randn(10, 10) # 400 bytes
large_tensor = torch.randn(20, 20) # 1600 bytes large_tensor = torch.randn(20, 20) # 1600 bytes
...@@ -140,12 +142,12 @@ class TestEncoderCacheManagerSizeTracking: ...@@ -140,12 +142,12 @@ class TestEncoderCacheManagerSizeTracking:
assert cache.stats["current_bytes"] > initial_size assert cache.stats["current_bytes"] > initial_size
class TestEncoderCacheManagerStats: class TestMultimodalEmbeddingCacheManagerStats:
"""Tests for statistics tracking.""" """Tests for statistics tracking."""
def test_hit_miss_tracking(self): def test_hit_miss_tracking(self):
"""Test hit and miss counting.""" """Test hit and miss counting."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
cache.set("key1", tensor) cache.set("key1", tensor)
...@@ -166,7 +168,7 @@ class TestEncoderCacheManagerStats: ...@@ -166,7 +168,7 @@ class TestEncoderCacheManagerStats:
def test_stats_content(self): def test_stats_content(self):
"""Test stats dictionary contains expected keys.""" """Test stats dictionary contains expected keys."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
cache.set("key1", tensor) cache.set("key1", tensor)
...@@ -186,7 +188,7 @@ class TestEncoderCacheManagerStats: ...@@ -186,7 +188,7 @@ class TestEncoderCacheManagerStats:
def test_utilization_calculation(self): def test_utilization_calculation(self):
"""Test utilization is calculated correctly.""" """Test utilization is calculated correctly."""
capacity = 1000 capacity = 1000
cache = EncoderCacheManager(capacity_bytes=capacity) cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
# Create tensor of known size # Create tensor of known size
# float32 = 4 bytes, so 25 elements = 100 bytes # float32 = 4 bytes, so 25 elements = 100 bytes
...@@ -198,12 +200,12 @@ class TestEncoderCacheManagerStats: ...@@ -198,12 +200,12 @@ class TestEncoderCacheManagerStats:
assert abs(stats["utilization"] - expected_utilization) < 0.001 assert abs(stats["utilization"] - expected_utilization) < 0.001
class TestEncoderCacheManagerContiguousTensor: class TestMultimodalEmbeddingCacheManagerContiguousTensor:
"""Tests for contiguous tensor requirement.""" """Tests for contiguous tensor requirement."""
def test_set_contiguous_tensor_succeeds(self): def test_set_contiguous_tensor_succeeds(self):
"""Test that contiguous tensors can be cached.""" """Test that contiguous tensors can be cached."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
assert tensor.is_contiguous() assert tensor.is_contiguous()
...@@ -212,7 +214,7 @@ class TestEncoderCacheManagerContiguousTensor: ...@@ -212,7 +214,7 @@ class TestEncoderCacheManagerContiguousTensor:
def test_set_non_contiguous_tensor_raises(self): def test_set_non_contiguous_tensor_raises(self):
"""Test that non-contiguous tensors raise AssertionError.""" """Test that non-contiguous tensors raise AssertionError."""
cache = EncoderCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
# Create a non-contiguous tensor via transpose # Create a non-contiguous tensor via transpose
tensor = torch.randn(10, 20).t() tensor = torch.randn(10, 20).t()
......
...@@ -8,7 +8,9 @@ import asyncio ...@@ -8,7 +8,9 @@ import asyncio
import pytest import pytest
import torch import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
...@@ -18,7 +20,7 @@ class TestAsyncEncoderCacheBasicOperations: ...@@ -18,7 +20,7 @@ class TestAsyncEncoderCacheBasicOperations:
@pytest.fixture @pytest.fixture
def cache(self): def cache(self):
"""Create a cache for testing.""" """Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024) ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm) return AsyncEncoderCache(ecm)
def test_sync_get_returns_none_for_missing_key(self, cache): def test_sync_get_returns_none_for_missing_key(self, cache):
...@@ -74,7 +76,7 @@ class TestAsyncEncoderCacheRequestCoalescing: ...@@ -74,7 +76,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
@pytest.fixture @pytest.fixture
def cache(self): def cache(self):
"""Create a cache for testing.""" """Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024) ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm) return AsyncEncoderCache(ecm)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -137,7 +139,7 @@ class TestAsyncEncoderCacheExceptionHandling: ...@@ -137,7 +139,7 @@ class TestAsyncEncoderCacheExceptionHandling:
@pytest.fixture @pytest.fixture
def cache(self): def cache(self):
"""Create a cache for testing.""" """Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024) ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm) return AsyncEncoderCache(ecm)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -209,7 +211,7 @@ class TestAsyncEncoderCacheStats: ...@@ -209,7 +211,7 @@ class TestAsyncEncoderCacheStats:
@pytest.fixture @pytest.fixture
def cache(self): def cache(self):
"""Create a cache for testing.""" """Create a cache for testing."""
ecm = EncoderCacheManager(capacity_bytes=1024 * 1024) ecm = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
return AsyncEncoderCache(ecm) return AsyncEncoderCache(ecm)
def test_stats_includes_in_flight(self, cache): def test_stats_includes_in_flight(self, cache):
......
...@@ -14,7 +14,9 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -14,7 +14,9 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.multimodal.async_encoder_cache import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles
from dynamo.trtllm.multimodal.hasher import MultimodalHasher from dynamo.trtllm.multimodal.hasher import MultimodalHasher
...@@ -25,7 +27,7 @@ async def fetch_embeddings_from_encoder( ...@@ -25,7 +27,7 @@ async def fetch_embeddings_from_encoder(
image_urls: List[str], image_urls: List[str],
request: Dict[str, Any], request: Dict[str, Any],
encode_client: Any, encode_client: Any,
encoder_cache: Optional[EncoderCacheManager] = None, encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
) -> Union[List[torch.Tensor], DisaggregatedParams]: ) -> Union[List[torch.Tensor], DisaggregatedParams]:
""" """
Fetch embeddings from remote encode worker. Fetch embeddings from remote encode worker.
...@@ -112,7 +114,7 @@ async def _remote_encode_full_epd( ...@@ -112,7 +114,7 @@ async def _remote_encode_full_epd(
async def _fetch_embeddings_with_cache( async def _fetch_embeddings_with_cache(
image_urls: List[str], image_urls: List[str],
request: Dict[str, Any], request: Dict[str, Any],
cache: EncoderCacheManager, cache: MultimodalEmbeddingCacheManager,
encode_fn: Callable[[Dict[str, Any]], DisaggregatedParams], encode_fn: Callable[[Dict[str, Any]], DisaggregatedParams],
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
""" """
......
...@@ -7,7 +7,9 @@ import logging ...@@ -7,7 +7,9 @@ import logging
from typing import Optional from typing import Optional
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.request_handlers.handler_base import ( from dynamo.trtllm.request_handlers.handler_base import (
HandlerBase, HandlerBase,
...@@ -26,7 +28,7 @@ class AggregatedHandler(HandlerBase): ...@@ -26,7 +28,7 @@ class AggregatedHandler(HandlerBase):
def __init__( def __init__(
self, self,
config: RequestHandlerConfig, config: RequestHandlerConfig,
encoder_cache: Optional[EncoderCacheManager] = None, encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
): ):
super().__init__(config) super().__init__(config)
self._encoder_cache = encoder_cache self._encoder_cache = encoder_cache
......
...@@ -5,7 +5,9 @@ import logging ...@@ -5,7 +5,9 @@ import logging
from typing import Optional from typing import Optional
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
...@@ -35,7 +37,7 @@ class RequestHandlerFactory: ...@@ -35,7 +37,7 @@ class RequestHandlerFactory:
encoder_cache = None encoder_cache = None
if config.encoder_cache_capacity_gb > 0: if config.encoder_cache_capacity_gb > 0:
capacity_bytes = int(config.encoder_cache_capacity_gb * 1024**3) capacity_bytes = int(config.encoder_cache_capacity_gb * 1024**3)
encoder_cache = EncoderCacheManager(capacity_bytes) encoder_cache = MultimodalEmbeddingCacheManager(capacity_bytes)
if config.disaggregation_mode.value == "prefill": if config.disaggregation_mode.value == "prefill":
return PrefillHandler(config, encoder_cache=encoder_cache) return PrefillHandler(config, encoder_cache=encoder_cache)
if config.disaggregation_mode.value == "prefill_and_decode": if config.disaggregation_mode.value == "prefill_and_decode":
...@@ -90,7 +92,7 @@ class PrefillHandler(HandlerBase): ...@@ -90,7 +92,7 @@ class PrefillHandler(HandlerBase):
def __init__( def __init__(
self, self,
config: RequestHandlerConfig, config: RequestHandlerConfig,
encoder_cache: Optional[EncoderCacheManager] = None, encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
): ):
super().__init__(config) super().__init__(config)
self._encoder_cache = encoder_cache self._encoder_cache = encoder_cache
......
...@@ -10,7 +10,9 @@ import pytest ...@@ -10,7 +10,9 @@ import pytest
import torch import torch
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.multimodal.hasher import MultimodalHasher from dynamo.trtllm.multimodal.hasher import MultimodalHasher
...@@ -53,9 +55,9 @@ def create_mock_encode_client( ...@@ -53,9 +55,9 @@ def create_mock_encode_client(
@pytest.fixture @pytest.fixture
def encoder_cache() -> EncoderCacheManager: def encoder_cache() -> MultimodalEmbeddingCacheManager:
"""Create encoder cache with 10MB capacity.""" """Create encoder cache with 10MB capacity."""
return EncoderCacheManager(capacity_bytes=10 * 1024 * 1024) return MultimodalEmbeddingCacheManager(capacity_bytes=10 * 1024 * 1024)
class TestFetchEmbeddingsFromEncoder: class TestFetchEmbeddingsFromEncoder:
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
import pytest import pytest
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.request_handlers.handlers import ( from dynamo.trtllm.request_handlers.handlers import (
AggregatedHandler, AggregatedHandler,
PrefillHandler, PrefillHandler,
...@@ -53,7 +55,7 @@ class TestRequestHandlerFactory: ...@@ -53,7 +55,7 @@ class TestRequestHandlerFactory:
assert isinstance(handler, PrefillHandler) assert isinstance(handler, PrefillHandler)
def test_prefill_handler_with_encoder_cache(self): def test_prefill_handler_with_encoder_cache(self):
"""Test factory creates PrefillHandler with EncoderCacheManager when capacity > 0.""" """Test factory creates PrefillHandler with MultimodalEmbeddingCacheManager when capacity > 0."""
mock_config = create_mock_request_handler_config( mock_config = create_mock_request_handler_config(
disaggregation_mode="prefill", disaggregation_mode="prefill",
encoder_cache_capacity_gb=1.0, encoder_cache_capacity_gb=1.0,
...@@ -62,7 +64,7 @@ class TestRequestHandlerFactory: ...@@ -62,7 +64,7 @@ class TestRequestHandlerFactory:
handler = factory.get_request_handler(mock_config) handler = factory.get_request_handler(mock_config)
assert isinstance(handler, PrefillHandler) assert isinstance(handler, PrefillHandler)
assert isinstance(handler._encoder_cache, EncoderCacheManager) assert isinstance(handler._encoder_cache, MultimodalEmbeddingCacheManager)
def test_prefill_handler_without_encoder_cache(self): def test_prefill_handler_without_encoder_cache(self):
"""Test factory creates PrefillHandler with no cache when capacity is 0.""" """Test factory creates PrefillHandler with no cache when capacity is 0."""
......
...@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch ...@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
def create_mock_encoder_cache() -> MagicMock: def create_mock_encoder_cache() -> MagicMock:
"""Create mock EncoderCacheManager.""" """Create mock MultimodalEmbeddingCacheManager."""
cache = MagicMock() cache = MagicMock()
cache.get = MagicMock(return_value=None) cache.get = MagicMock(return_value=None)
cache.set = MagicMock(return_value=True) cache.set = MagicMock(return_value=True)
......
...@@ -53,6 +53,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.trtllm \ ...@@ -53,6 +53,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.trtllm \
--modality "$MODALITY" \ --modality "$MODALITY" \
--custom-jinja-template "$CUSTOM_TEMPLATE" \ --custom-jinja-template "$CUSTOM_TEMPLATE" \
--encode-endpoint "$ENCODE_ENDPOINT" \ --encode-endpoint "$ENCODE_ENDPOINT" \
--disaggregation-mode prefill_and_decode \
--dyn-encoder-cache-capacity-gb "$DYN_ENCODER_CACHE_CAPACITY_GB" & --dyn-encoder-cache-capacity-gb "$DYN_ENCODER_CACHE_CAPACITY_GB" &
PD_PID_1=$! PD_PID_1=$!
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment