Unverified Commit 838ba140 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: async encoder cache impl (#5676)

parent 7d3c67f0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Multimodal utilities for Dynamo components."""
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
__all__ = ["AsyncEncoderCache"]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Async Encoder Cache
Async wrapper over EncoderCacheManager with request coalescing.
Prevents duplicate encoding when multiple requests arrive for the same content.
Usage:
cache = EncoderCacheManager(capacity_bytes=4 * 1024**3)
async_cache = AsyncEncoderCache(cache)
# Get from cache or compute with coalescing
tensor = await async_cache.get_or_compute("hash123", encoder.encode)
"""
import asyncio
import logging
from typing import Awaitable, Callable, Dict, Optional
import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
logger = logging.getLogger(__name__)
def _suppress_unhandled_future_exception(future: asyncio.Future) -> None:
"""
Callback to prevent 'Future exception was never retrieved' warning.
When a Future has set_exception() called but no one awaits it (e.g., single
caller that gets the exception via re-raise), asyncio warns. This callback
retrieves the exception to suppress that warning.
"""
if future.done() and not future.cancelled():
try:
future.exception()
except asyncio.CancelledError:
pass
class AsyncEncoderCache:
    """
    Async wrapper with request coalescing over EncoderCacheManager.

    Provides async get_or_compute that deduplicates concurrent requests
    for the same key, ensuring only one encoding runs at a time per key.

    Cancellation:
        Coroutines waiting on an in-flight computation are shielded from the
        shared future, so cancelling one waiter affects neither the running
        computation nor the other waiters. If the coroutine that runs the
        computation is cancelled, the remaining waiters retry and one of them
        takes over the computation instead of waiting forever.

    Thread Safety:
        This class is NOT thread-safe. It is designed to run within a single
        asyncio event loop. All access must be from the same thread.
    """

    def __init__(self, cache: EncoderCacheManager):
        """
        Initialize the async encoder cache.

        Args:
            cache: Underlying EncoderCacheManager for storage.
        """
        self._cache = cache
        # key -> future resolved by the coroutine currently computing that key
        self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {}

    def get(self, key: str) -> Optional[torch.Tensor]:
        """
        Synchronous get from underlying cache.

        Args:
            key: Cache key.

        Returns:
            Cached tensor or None if not found.
        """
        return self._cache.get(key)

    @staticmethod
    def _cancel_requested() -> bool:
        """Return True if the current task has a pending cancel request.

        Relies on ``Task.cancelling()`` (Python 3.11+); on older versions the
        information is unavailable and False is returned.
        """
        cancelling = getattr(asyncio.current_task(), "cancelling", None)
        return cancelling is not None and bool(cancelling())

    async def get_or_compute(
        self,
        key: str,
        compute_fn: Callable[[], Awaitable[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Get from cache or compute with request coalescing.

        If the key is in cache, returns immediately.
        If another coroutine is already computing this key, waits for that result.
        Otherwise, computes and caches the result.

        Args:
            key: Cache key (typically content hash).
            compute_fn: Async function to compute the tensor if not cached.

        Returns:
            The cached or computed tensor.

        Raises:
            Exception: Re-raises any exception from compute_fn (or from storing
                the result). Every coroutine waiting on the same computation
                receives the same exception; nothing is cached.
            asyncio.CancelledError: If the calling task is cancelled.
        """
        while True:
            # Check cache first
            cached = self._cache.get(key)
            if cached is not None:
                return cached

            pending = self._in_flight.get(key)
            if pending is None:
                # Nobody is computing this key: this coroutine becomes the owner.
                break

            # Wait if already in-flight
            logger.debug("Waiting for in-flight computation: key=%s...", key[:16])
            try:
                # shield() keeps a cancelled waiter from cancelling the shared
                # future, which would break the owner and every other waiter.
                return await asyncio.shield(pending)
            except asyncio.CancelledError:
                if not pending.cancelled() or self._cancel_requested():
                    # This waiter itself was cancelled: propagate.
                    raise
                # The owner was cancelled before producing a result. Drop the
                # stale entry if it is still registered (avoids spinning on an
                # already-cancelled future) and retry from the top.
                if self._in_flight.get(key) is pending:
                    del self._in_flight[key]

        # Compute with coalescing
        future: asyncio.Future[torch.Tensor] = (
            asyncio.get_running_loop().create_future()
        )
        future.add_done_callback(_suppress_unhandled_future_exception)
        self._in_flight[key] = future
        try:
            tensor = await compute_fn()
            self._cache.set(key, tensor)
        except Exception as exc:
            if not future.done():
                future.set_exception(exc)
            raise
        except BaseException:
            # CancelledError (and KeyboardInterrupt/SystemExit) are not
            # Exception subclasses. Cancel the future so that waiters wake up
            # and retry rather than hang on a future nobody will resolve.
            future.cancel()
            raise
        else:
            if not future.done():
                future.set_result(tensor)
            return tensor
        finally:
            # Only remove our own registration.
            if self._in_flight.get(key) is future:
                del self._in_flight[key]

    @property
    def stats(self) -> dict:
        """
        Get cache statistics from underlying cache.

        Returns:
            A new dictionary with the underlying cache stats plus an
            ``in_flight`` entry holding the number of keys being computed.
        """
        # Copy so the extra key is never written into the object returned by
        # the underlying cache.
        combined = dict(self._cache.stats)
        combined["in_flight"] = len(self._in_flight)
        return combined
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for AsyncEncoderCache."""
import asyncio
import pytest
import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
class TestAsyncEncoderCacheBasicOperations:
    """Exercises plain get / get_or_compute without any concurrency."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache backed by a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    def test_sync_get_returns_none_for_missing_key(self, cache):
        """A key that was never stored yields None from the sync getter."""
        missing = cache.get("nonexistent")
        assert missing is None

    def test_sync_get_returns_cached_tensor(self, cache):
        """A tensor placed in the underlying cache is visible through get()."""
        expected = torch.randn(10, 10)
        cache._cache.set("key1", expected)

        assert torch.equal(cache.get("key1"), expected)

    @pytest.mark.asyncio
    async def test_get_or_compute_caches_result(self, cache):
        """The value produced by compute_fn is returned and then stored."""
        expected = torch.randn(10, 10)

        async def produce():
            return expected

        computed = await cache.get_or_compute("key1", produce)
        assert torch.equal(computed, expected)

        # A follow-up synchronous lookup must now find the tensor.
        stored = cache.get("key1")
        assert stored is not None
        assert torch.equal(stored, expected)

    @pytest.mark.asyncio
    async def test_get_or_compute_returns_cached(self, cache):
        """A pre-populated key is served without invoking compute_fn."""
        expected = torch.randn(10, 10)
        cache._cache.set("key1", expected)
        invocations = []

        async def produce():
            invocations.append(1)
            return torch.randn(10, 10)

        fetched = await cache.get_or_compute("key1", produce)

        assert torch.equal(fetched, expected)
        assert not invocations
class TestAsyncEncoderCacheRequestCoalescing:
    """Checks that concurrent get_or_compute calls are deduplicated per key."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache backed by a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    @pytest.mark.asyncio
    async def test_concurrent_requests_coalesce(self, cache):
        """Three simultaneous requests for one key trigger a single compute."""
        expected = torch.randn(10, 10)
        started = asyncio.Event()
        release = asyncio.Event()
        invocations = []

        async def produce():
            invocations.append(1)
            started.set()  # tell the test body that the computation is running
            await release.wait()  # block until the test lets it finish
            return expected

        tasks = [
            asyncio.create_task(cache.get_or_compute("key1", produce))
            for _ in range(3)
        ]

        # Hold the computation open until it has definitely started, then
        # let it complete.
        await started.wait()
        release.set()
        outcomes = await asyncio.gather(*tasks)

        # Every caller sees the same tensor ...
        assert all(torch.equal(outcome, expected) for outcome in outcomes)
        # ... yet the compute function ran exactly once.
        assert len(invocations) == 1

    @pytest.mark.asyncio
    async def test_different_keys_compute_separately(self, cache):
        """Distinct keys are not coalesced: each one runs its own compute."""
        invocations = []

        async def produce():
            invocations.append(1)
            return torch.randn(10, 10)

        await asyncio.gather(
            *(cache.get_or_compute(key, produce) for key in ("key1", "key2", "key3"))
        )

        assert len(invocations) == 3
class TestAsyncEncoderCacheExceptionHandling:
    """Checks how compute_fn failures surface to callers and waiters."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache backed by a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    @pytest.mark.asyncio
    async def test_exception_propagates_to_caller(self, cache):
        """An error raised by compute_fn reaches the single caller."""

        async def explode():
            raise ValueError("compute failed")

        with pytest.raises(ValueError, match="compute failed"):
            await cache.get_or_compute("key1", explode)

    @pytest.mark.asyncio
    async def test_exception_propagates_to_all_waiters(self, cache):
        """Every coroutine coalesced onto a failing compute gets the error."""
        started = asyncio.Event()
        release = asyncio.Event()

        async def explode():
            started.set()
            await release.wait()
            raise ValueError("compute failed")

        tasks = [
            asyncio.create_task(cache.get_or_compute("key1", explode))
            for _ in range(2)
        ]

        # Let the computation begin, then allow it to run into its failure.
        await started.wait()
        release.set()

        # return_exceptions=True collects the failures instead of raising the
        # first one, so each task's outcome can be inspected.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        assert len(outcomes) == 2
        assert all(
            isinstance(outcome, ValueError) and str(outcome) == "compute failed"
            for outcome in outcomes
        )

    @pytest.mark.asyncio
    async def test_in_flight_cleared_after_exception(self, cache):
        """A failed compute leaves no in-flight entry, so a retry works."""

        async def failing_compute():
            raise ValueError("compute failed")

        with pytest.raises(ValueError):
            await cache.get_or_compute("key1", failing_compute)

        # No leftover bookkeeping for the failed key.
        assert not cache._in_flight

        # The same key can be computed successfully afterwards.
        expected = torch.randn(10, 10)

        async def working_compute():
            return expected

        retried = await cache.get_or_compute("key1", working_compute)
        assert torch.equal(retried, expected)
class TestAsyncEncoderCacheStats:
    """Checks the contents of the stats property."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache backed by a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    def test_stats_includes_in_flight(self, cache):
        """A fresh cache reports an in_flight counter of zero."""
        snapshot = cache.stats

        assert "in_flight" in snapshot
        assert snapshot["in_flight"] == 0

    @pytest.mark.asyncio
    async def test_stats_reflects_underlying_cache(self, cache):
        """After one computed entry, stats show it and nothing in flight."""
        expected = torch.randn(10, 10)

        async def produce():
            return expected

        await cache.get_or_compute("key1", produce)
        snapshot = cache.stats

        assert snapshot["entries"] == 1
        # The only lookup performed above was a miss (assuming "hits" counts
        # successful lookups on the underlying cache).
        assert snapshot["hits"] == 0
        assert snapshot["in_flight"] == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment