Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
838ba140
Unverified
Commit
838ba140
authored
Jan 28, 2026
by
Qi Wang
Committed by
GitHub
Jan 28, 2026
Browse files
feat: async encoder cache impl (#5676)
parent
7d3c67f0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
379 additions
and
0 deletions
+379
-0
components/src/dynamo/common/multimodal/__init__.py
components/src/dynamo/common/multimodal/__init__.py
+8
-0
components/src/dynamo/common/multimodal/async_encoder_cache.py
...nents/src/dynamo/common/multimodal/async_encoder_cache.py
+135
-0
components/src/dynamo/common/tests/multimodal/test_async_encoder_cache.py
...ynamo/common/tests/multimodal/test_async_encoder_cache.py
+236
-0
No files found.
components/src/dynamo/common/multimodal/__init__.py
0 → 100644
View file @
838ba140
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Multimodal utilities for Dynamo components."""
# Re-export AsyncEncoderCache so it can be imported from the package itself.
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache

# Names exported by ``from dynamo.common.multimodal import *``.
__all__ = ["AsyncEncoderCache"]
components/src/dynamo/common/multimodal/async_encoder_cache.py
0 → 100644
View file @
838ba140
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Async Encoder Cache
Async wrapper over EncoderCacheManager with request coalescing.
Prevents duplicate encoding when multiple requests arrive for the same content.
Usage:
cache = EncoderCacheManager(capacity_bytes=4 * 1024**3)
async_cache = AsyncEncoderCache(cache)
# Get from cache or compute with coalescing
tensor = await async_cache.get_or_compute("hash123", encoder.encode)
"""
import
asyncio
import
logging
from
typing
import
Awaitable
,
Callable
,
Dict
,
Optional
import
torch
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def _suppress_unhandled_future_exception(future: asyncio.Future) -> None:
    """
    Done-callback that marks a future's exception as retrieved.

    asyncio logs 'Future exception was never retrieved' when a future that
    had set_exception() called on it is garbage-collected without anyone
    reading the exception (e.g. a single caller that already received the
    error via re-raise). Reading the exception here silences that warning.

    Args:
        future: The future to inspect. Pending and cancelled futures are
            left untouched.
    """
    # Nothing to retrieve from a future that is still pending or was cancelled.
    if not future.done() or future.cancelled():
        return

    try:
        # The call itself is what flags the exception as retrieved.
        future.exception()
    except asyncio.CancelledError:
        pass
class AsyncEncoderCache:
    """
    Async wrapper with request coalescing over EncoderCacheManager.

    Provides async get_or_compute that deduplicates concurrent requests
    for the same key, ensuring only one encoding runs at a time per key.

    Cancellation:
        Cancelling a coroutine that is merely waiting on another caller's
        computation affects only that waiter. If the coroutine that is
        actually running the computation is cancelled, every waiter for the
        same key receives asyncio.CancelledError instead of hanging.

    Thread Safety:
        This class is NOT thread-safe. It is designed to run within a single
        asyncio event loop. All access must be from the same thread.
    """

    def __init__(self, cache: EncoderCacheManager):
        """
        Initialize the async encoder cache.

        Args:
            cache: Underlying EncoderCacheManager for storage.
        """
        self._cache = cache
        # key -> future resolved by the coroutine currently computing that key.
        self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {}

    def get(self, key: str) -> Optional[torch.Tensor]:
        """
        Synchronous get from underlying cache.

        Args:
            key: Cache key.

        Returns:
            Cached tensor or None if not found.
        """
        return self._cache.get(key)

    async def get_or_compute(
        self,
        key: str,
        compute_fn: Callable[[], Awaitable[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Get from cache or compute with request coalescing.

        If the key is in cache, returns immediately.
        If another coroutine is already computing this key, waits for that result.
        Otherwise, computes and caches the result.

        Args:
            key: Cache key (typically content hash).
            compute_fn: Async function to compute the tensor if not cached.

        Returns:
            The cached or computed tensor.

        Raises:
            Exception: Re-raises any exception from compute_fn (or from
                storing the result); waiters for the same key see the same
                exception.
            asyncio.CancelledError: If this coroutine is cancelled, or if it
                was waiting on a computation whose owner was cancelled.
        """
        # Check cache first
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        # Wait if already in-flight
        pending = self._in_flight.get(key)
        if pending is not None:
            logger.debug("Waiting for in-flight computation: key=%s...", key[:16])
            # shield() keeps a cancelled waiter from cancelling the shared
            # future, which would break the owner's set_result() and cancel
            # every other waiter for this key.
            return await asyncio.shield(pending)

        # Compute with coalescing
        future: asyncio.Future[torch.Tensor] = (
            asyncio.get_running_loop().create_future()
        )
        future.add_done_callback(_suppress_unhandled_future_exception)
        self._in_flight[key] = future

        try:
            tensor = await compute_fn()
            self._cache.set(key, tensor)
            # Guard against the future having been resolved/cancelled externally.
            if not future.done():
                future.set_result(tensor)
            return tensor
        except Exception as e:
            if not future.done():
                future.set_exception(e)
            raise
        finally:
            self._in_flight.pop(key, None)
            # Reached with a pending future only when a BaseException that is
            # not an Exception escaped (typically CancelledError). Cancel the
            # future so waiters are released instead of blocking forever.
            if not future.done():
                future.cancel()

    @property
    def stats(self) -> dict:
        """
        Get cache statistics from underlying cache.

        Returns:
            A new dictionary with the underlying cache stats plus an
            "in_flight" entry holding the number of computations currently
            running.
        """
        # Copy so that adding "in_flight" never mutates the object handed out
        # by the underlying cache (it may be the manager's own dict).
        combined = dict(self._cache.stats)
        combined["in_flight"] = len(self._in_flight)
        return combined
components/src/dynamo/common/tests/multimodal/test_async_encoder_cache.py
0 → 100644
View file @
838ba140
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for AsyncEncoderCache."""
import
asyncio
import
pytest
import
torch
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.common.multimodal.async_encoder_cache
import
AsyncEncoderCache
class TestAsyncEncoderCacheBasicOperations:
    """Basic get / get_or_compute behaviour."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache over a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    def test_sync_get_returns_none_for_missing_key(self, cache):
        """A key that was never stored yields None from the sync getter."""
        missing = cache.get("nonexistent")
        assert missing is None

    def test_sync_get_returns_cached_tensor(self, cache):
        """A tensor placed in the underlying cache is returned by get()."""
        expected = torch.randn(10, 10)
        cache._cache.set("key1", expected)

        assert torch.equal(cache.get("key1"), expected)

    @pytest.mark.asyncio
    async def test_get_or_compute_caches_result(self, cache):
        """The value produced by compute_fn is returned and then stored."""
        expected = torch.randn(10, 10)

        async def produce():
            return expected

        computed = await cache.get_or_compute("key1", produce)
        assert torch.equal(computed, expected)

        # A follow-up sync lookup must now find the entry.
        stored = cache.get("key1")
        assert stored is not None
        assert torch.equal(stored, expected)

    @pytest.mark.asyncio
    async def test_get_or_compute_returns_cached(self, cache):
        """An already-cached key is served without invoking compute_fn."""
        expected = torch.randn(10, 10)
        cache._cache.set("key1", expected)

        invocations = []

        async def produce():
            invocations.append(None)
            return torch.randn(10, 10)

        outcome = await cache.get_or_compute("key1", produce)

        assert torch.equal(outcome, expected)
        assert not invocations
class TestAsyncEncoderCacheRequestCoalescing:
    """Coalescing of concurrent get_or_compute calls."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache over a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    @pytest.mark.asyncio
    async def test_concurrent_requests_coalesce(self, cache):
        """Three concurrent requests for one key trigger a single compute."""
        expected = torch.randn(10, 10)
        started = asyncio.Event()
        release = asyncio.Event()
        invocations = []

        async def produce():
            invocations.append(None)
            started.set()  # tell the test that compute is running
            await release.wait()  # hold here until the test lets us finish
            return expected

        # Launch the requests as tasks so they run concurrently.
        tasks = [
            asyncio.create_task(cache.get_or_compute("key1", produce))
            for _ in range(3)
        ]

        # Once compute has started, the other requests have been queued.
        await started.wait()

        # Let the single computation complete.
        release.set()
        outcomes = await asyncio.gather(*tasks)

        # Every request observes the same tensor ...
        assert all(torch.equal(outcome, expected) for outcome in outcomes)
        # ... while compute ran exactly once.
        assert len(invocations) == 1

    @pytest.mark.asyncio
    async def test_different_keys_compute_separately(self, cache):
        """Requests for distinct keys each run their own compute."""
        invocations = []

        async def produce():
            invocations.append(None)
            return torch.randn(10, 10)

        await asyncio.gather(
            *(cache.get_or_compute(key, produce) for key in ("key1", "key2", "key3"))
        )

        assert len(invocations) == 3
class TestAsyncEncoderCacheExceptionHandling:
    """Behaviour when compute_fn raises."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache over a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    @pytest.mark.asyncio
    async def test_exception_propagates_to_caller(self, cache):
        """An error raised by compute_fn reaches the awaiting caller."""

        async def broken():
            raise ValueError("compute failed")

        with pytest.raises(ValueError, match="compute failed"):
            await cache.get_or_compute("key1", broken)

    @pytest.mark.asyncio
    async def test_exception_propagates_to_all_waiters(self, cache):
        """Every coroutine waiting on a failing compute sees the error."""
        started = asyncio.Event()
        release = asyncio.Event()

        async def broken():
            started.set()
            await release.wait()
            raise ValueError("compute failed")

        # Launch the requests as tasks so they run concurrently.
        tasks = [
            asyncio.create_task(cache.get_or_compute("key1", broken))
            for _ in range(2)
        ]

        # Wait until compute is running, then let it go ahead and fail.
        await started.wait()
        release.set()

        # return_exceptions=True collects the errors instead of raising.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        # Both tasks must have ended with the compute error.
        assert len(outcomes) == 2
        for outcome in outcomes:
            assert isinstance(outcome, ValueError)
            assert str(outcome) == "compute failed"

    @pytest.mark.asyncio
    async def test_in_flight_cleared_after_exception(self, cache):
        """A failed compute leaves no in-flight entry, so a retry works."""

        async def failing_compute():
            raise ValueError("compute failed")

        with pytest.raises(ValueError):
            await cache.get_or_compute("key1", failing_compute)

        # No in-flight entry may remain after the failure.
        assert len(cache._in_flight) == 0

        # A second attempt for the same key must succeed.
        expected = torch.randn(10, 10)

        async def working_compute():
            return expected

        retried = await cache.get_or_compute("key1", working_compute)
        assert torch.equal(retried, expected)
class TestAsyncEncoderCacheStats:
    """The stats property."""

    @pytest.fixture
    def cache(self):
        """Provide an AsyncEncoderCache over a 1024 * 1024 byte manager."""
        return AsyncEncoderCache(EncoderCacheManager(capacity_bytes=1024 * 1024))

    def test_stats_includes_in_flight(self, cache):
        """An idle cache reports an in_flight count of zero."""
        snapshot = cache.stats

        assert "in_flight" in snapshot
        assert snapshot["in_flight"] == 0

    @pytest.mark.asyncio
    async def test_stats_reflects_underlying_cache(self, cache):
        """Stats expose the underlying cache's counters after one compute."""
        payload = torch.randn(10, 10)

        async def produce():
            return payload

        await cache.get_or_compute("key1", produce)

        snapshot = cache.stats
        assert snapshot["entries"] == 1
        # NOTE(review): the only lookup so far was the miss inside
        # get_or_compute; presumably EncoderCacheManager counts only
        # successful lookups as hits - confirm against its implementation.
        assert snapshot["hits"] == 0
        assert snapshot["in_flight"] == 0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment