Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
c82fe888
Unverified
Commit
c82fe888
authored
Feb 20, 2026
by
Qi Wang
Committed by
GitHub
Feb 20, 2026
Browse files
feat: add embedding cache to pd worker (#6061)
parent
ebc61637
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
753 additions
and
261 deletions
+753
-261
components/src/dynamo/common/memory/multimodal_embedding_cache_manager.py
...ynamo/common/memory/multimodal_embedding_cache_manager.py
+21
-16
components/src/dynamo/common/tests/memory/test_multimodal_embedding_cache_manager.py
...n/tests/memory/test_multimodal_embedding_cache_manager.py
+60
-22
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
+3
-2
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
.../trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
+13
-3
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
.../vllm/multimodal_handlers/multimodal_pd_worker_handler.py
+37
-116
components/src/dynamo/vllm/multimodal_utils/__init__.py
components/src/dynamo/vllm/multimodal_utils/__init__.py
+2
-8
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
.../src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
+190
-37
components/src/dynamo/vllm/multimodal_utils/protocol.py
components/src/dynamo/vllm/multimodal_utils/protocol.py
+2
-0
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
...imodal_handlers/test_vllm_multimodal_pd_worker_handler.py
+88
-19
components/src/dynamo/vllm/tests/multimodal_utils/test_vllm_prefill_worker_utils.py
.../tests/multimodal_utils/test_vllm_prefill_worker_utils.py
+162
-0
examples/backends/vllm/launch/disagg_multimodal_e_pd.sh
examples/backends/vllm/launch/disagg_multimodal_e_pd.sh
+119
-0
tests/serve/conftest.py
tests/serve/conftest.py
+25
-3
tests/serve/test_vllm.py
tests/serve/test_vllm.py
+31
-35
No files found.
components/src/dynamo/common/memory/multimodal_embedding_cache_manager.py
View file @
c82fe888
...
@@ -19,13 +19,18 @@ Usage:
...
@@ -19,13 +19,18 @@ Usage:
import
logging
import
logging
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Optional
from
typing
import
NamedTuple
,
Optional
import
torch
import
torch
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
CachedEmbedding
(
NamedTuple
):
tensor
:
torch
.
Tensor
image_grid_thw
:
list
|
None
=
None
class
MultimodalEmbeddingCacheManager
:
class
MultimodalEmbeddingCacheManager
:
"""
"""
LRU cache for encoder embeddings.
LRU cache for encoder embeddings.
...
@@ -47,7 +52,7 @@ class MultimodalEmbeddingCacheManager:
...
@@ -47,7 +52,7 @@ class MultimodalEmbeddingCacheManager:
Args:
Args:
capacity_bytes: Maximum cache capacity in bytes.
capacity_bytes: Maximum cache capacity in bytes.
"""
"""
self
.
_cache
:
OrderedDict
[
str
,
torch
.
Tensor
]
=
OrderedDict
()
self
.
_cache
:
OrderedDict
[
str
,
CachedEmbedding
]
=
OrderedDict
()
self
.
_capacity_bytes
=
capacity_bytes
self
.
_capacity_bytes
=
capacity_bytes
self
.
_current_bytes
=
0
self
.
_current_bytes
=
0
...
@@ -77,9 +82,9 @@ class MultimodalEmbeddingCacheManager:
...
@@ -77,9 +82,9 @@ class MultimodalEmbeddingCacheManager:
),
"Tensor must be contiguous for accurate size calculation"
),
"Tensor must be contiguous for accurate size calculation"
return
tensor
.
element_size
()
*
tensor
.
numel
()
return
tensor
.
element_size
()
*
tensor
.
numel
()
def
get
(
self
,
key
:
str
)
->
Optional
[
torch
.
Tensor
]:
def
get
(
self
,
key
:
str
)
->
Optional
[
CachedEmbedding
]:
"""
"""
Get a tensor from the cache.
Get a cached embedding from the cache.
If found, the entry is moved to the end (most recently used).
If found, the entry is moved to the end (most recently used).
...
@@ -87,7 +92,7 @@ class MultimodalEmbeddingCacheManager:
...
@@ -87,7 +92,7 @@ class MultimodalEmbeddingCacheManager:
key: Cache key (typically content hash).
key: Cache key (typically content hash).
Returns:
Returns:
The cached tensor, or None if not found.
The cached embedding, or None if not found.
"""
"""
if
key
not
in
self
.
_cache
:
if
key
not
in
self
.
_cache
:
self
.
_misses
+=
1
self
.
_misses
+=
1
...
@@ -98,22 +103,22 @@ class MultimodalEmbeddingCacheManager:
...
@@ -98,22 +103,22 @@ class MultimodalEmbeddingCacheManager:
self
.
_hits
+=
1
self
.
_hits
+=
1
return
self
.
_cache
[
key
]
return
self
.
_cache
[
key
]
def
set
(
self
,
key
:
str
,
t
en
sor
:
torch
.
Tensor
)
->
bool
:
def
set
(
self
,
key
:
str
,
en
try
:
CachedEmbedding
)
->
bool
:
"""
"""
Store a tensor in the cache.
Store a cached embedding in the cache.
If the key already exists, the old value is replaced.
If the key already exists, the old value is replaced.
If adding the tensor would exceed capacity, LRU entries are evicted.
If adding the entry would exceed capacity, LRU entries are evicted.
If the tensor itself is larger than capacity, it is not stored.
If the tensor itself is larger than capacity, it is not stored.
Args:
Args:
key: Cache key (typically content hash).
key: Cache key (typically content hash).
tensor: Tensor to cache.
entry: CachedEmbedding to cache.
Returns:
Returns:
True if the tensor was stored, False if it was too large.
True if the entry was stored, False if it was too large.
"""
"""
size
=
self
.
_tensor_size
(
tensor
)
size
=
self
.
_tensor_size
(
entry
.
tensor
)
# Don't cache if single tensor exceeds capacity
# Don't cache if single tensor exceeds capacity
if
size
>
self
.
_capacity_bytes
:
if
size
>
self
.
_capacity_bytes
:
...
@@ -125,20 +130,20 @@ class MultimodalEmbeddingCacheManager:
...
@@ -125,20 +130,20 @@ class MultimodalEmbeddingCacheManager:
# If key exists, remove old entry first
# If key exists, remove old entry first
if
key
in
self
.
_cache
:
if
key
in
self
.
_cache
:
old_
t
en
sor
=
self
.
_cache
.
pop
(
key
)
old_en
try
=
self
.
_cache
.
pop
(
key
)
self
.
_current_bytes
-=
self
.
_tensor_size
(
old_tensor
)
self
.
_current_bytes
-=
self
.
_tensor_size
(
old_
entry
.
tensor
)
# Evict LRU entries until we have space
# Evict LRU entries until we have space
while
self
.
_current_bytes
+
size
>
self
.
_capacity_bytes
and
self
.
_cache
:
while
self
.
_current_bytes
+
size
>
self
.
_capacity_bytes
and
self
.
_cache
:
evicted_key
,
evicted_
t
en
sor
=
self
.
_cache
.
popitem
(
last
=
False
)
evicted_key
,
evicted_en
try
=
self
.
_cache
.
popitem
(
last
=
False
)
evicted_size
=
self
.
_tensor_size
(
evicted_tensor
)
evicted_size
=
self
.
_tensor_size
(
evicted_
entry
.
tensor
)
self
.
_current_bytes
-=
evicted_size
self
.
_current_bytes
-=
evicted_size
logger
.
debug
(
logger
.
debug
(
f
"Evicted key=
{
evicted_key
[:
16
]
}
..., size=
{
evicted_size
/
1024
**
2
:.
2
f
}
MB"
f
"Evicted key=
{
evicted_key
[:
16
]
}
..., size=
{
evicted_size
/
1024
**
2
:.
2
f
}
MB"
)
)
# Store new entry
# Store new entry
self
.
_cache
[
key
]
=
t
en
sor
self
.
_cache
[
key
]
=
en
try
self
.
_current_bytes
+=
size
self
.
_current_bytes
+=
size
logger
.
debug
(
logger
.
debug
(
...
...
components/src/dynamo/common/tests/memory/test_multimodal_embedding_cache_manager.py
View file @
c82fe888
...
@@ -7,6 +7,7 @@ import pytest
...
@@ -7,6 +7,7 @@ import pytest
import
torch
import
torch
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
CachedEmbedding
,
MultimodalEmbeddingCacheManager
,
MultimodalEmbeddingCacheManager
,
)
)
...
@@ -19,12 +20,25 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
...
@@ -19,12 +20,25 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
1024
*
1024
)
# 1MB
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
1024
*
1024
)
# 1MB
tensor
=
torch
.
randn
(
100
,
100
)
# ~40KB for float32
tensor
=
torch
.
randn
(
100
,
100
)
# ~40KB for float32
result
=
cache
.
set
(
"key1"
,
tensor
)
result
=
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
)
)
assert
result
is
True
assert
result
is
True
retrieved
=
cache
.
get
(
"key1"
)
retrieved
=
cache
.
get
(
"key1"
)
assert
retrieved
is
not
None
assert
retrieved
is
not
None
assert
torch
.
equal
(
retrieved
,
tensor
)
assert
torch
.
equal
(
retrieved
.
tensor
,
tensor
)
assert
retrieved
.
image_grid_thw
is
None
def
test_set_and_get_with_grid
(
self
):
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
1024
*
1024
)
tensor
=
torch
.
randn
(
100
,
100
)
grid
=
[[
1
,
2
,
3
]]
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
,
grid
))
retrieved
=
cache
.
get
(
"key1"
)
assert
retrieved
is
not
None
assert
torch
.
equal
(
retrieved
.
tensor
,
tensor
)
assert
retrieved
.
image_grid_thw
==
grid
def
test_get_nonexistent_key
(
self
):
def
test_get_nonexistent_key
(
self
):
"""Test get returns None for nonexistent key."""
"""Test get returns None for nonexistent key."""
...
@@ -39,11 +53,11 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
...
@@ -39,11 +53,11 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
tensor1
=
torch
.
randn
(
10
,
10
)
tensor1
=
torch
.
randn
(
10
,
10
)
tensor2
=
torch
.
randn
(
10
,
10
)
tensor2
=
torch
.
randn
(
10
,
10
)
cache
.
set
(
"key1"
,
tensor1
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor1
)
)
cache
.
set
(
"key1"
,
tensor2
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor2
)
)
retrieved
=
cache
.
get
(
"key1"
)
retrieved
=
cache
.
get
(
"key1"
)
assert
torch
.
equal
(
retrieved
,
tensor2
)
assert
torch
.
equal
(
retrieved
.
tensor
,
tensor2
)
assert
cache
.
stats
[
"entries"
]
==
1
assert
cache
.
stats
[
"entries"
]
==
1
...
@@ -61,11 +75,11 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
...
@@ -61,11 +75,11 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
t2
=
torch
.
randn
(
10
,
10
)
t2
=
torch
.
randn
(
10
,
10
)
t3
=
torch
.
randn
(
10
,
10
)
t3
=
torch
.
randn
(
10
,
10
)
cache
.
set
(
"key1"
,
t1
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
t1
)
)
cache
.
set
(
"key2"
,
t2
)
cache
.
set
(
"key2"
,
CachedEmbedding
(
t2
)
)
# Adding third should evict first (LRU)
# Adding third should evict first (LRU)
cache
.
set
(
"key3"
,
t3
)
cache
.
set
(
"key3"
,
CachedEmbedding
(
t3
)
)
assert
cache
.
get
(
"key1"
)
is
None
# Evicted
assert
cache
.
get
(
"key1"
)
is
None
# Evicted
assert
cache
.
get
(
"key2"
)
is
not
None
assert
cache
.
get
(
"key2"
)
is
not
None
...
@@ -81,14 +95,14 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
...
@@ -81,14 +95,14 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
t2
=
torch
.
randn
(
10
,
10
)
t2
=
torch
.
randn
(
10
,
10
)
t3
=
torch
.
randn
(
10
,
10
)
t3
=
torch
.
randn
(
10
,
10
)
cache
.
set
(
"key1"
,
t1
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
t1
)
)
cache
.
set
(
"key2"
,
t2
)
cache
.
set
(
"key2"
,
CachedEmbedding
(
t2
)
)
# Access key1, making key2 the LRU
# Access key1, making key2 the LRU
cache
.
get
(
"key1"
)
cache
.
get
(
"key1"
)
# Adding third should evict key2 (now LRU)
# Adding third should evict key2 (now LRU)
cache
.
set
(
"key3"
,
t3
)
cache
.
set
(
"key3"
,
CachedEmbedding
(
t3
)
)
assert
cache
.
get
(
"key1"
)
is
not
None
# Not evicted (recently accessed)
assert
cache
.
get
(
"key1"
)
is
not
None
# Not evicted (recently accessed)
assert
cache
.
get
(
"key2"
)
is
None
# Evicted (LRU)
assert
cache
.
get
(
"key2"
)
is
None
# Evicted (LRU)
...
@@ -99,7 +113,7 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
...
@@ -99,7 +113,7 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
100
)
# Very small
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
100
)
# Very small
tensor
=
torch
.
randn
(
100
,
100
)
# ~40KB, way larger than capacity
tensor
=
torch
.
randn
(
100
,
100
)
# ~40KB, way larger than capacity
result
=
cache
.
set
(
"key1"
,
tensor
)
result
=
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
)
)
assert
result
is
False
assert
result
is
False
assert
cache
.
get
(
"key1"
)
is
None
assert
cache
.
get
(
"key1"
)
is
None
...
@@ -119,10 +133,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
...
@@ -119,10 +133,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
expected_size_1
=
t1
.
element_size
()
*
t1
.
numel
()
expected_size_1
=
t1
.
element_size
()
*
t1
.
numel
()
expected_size_2
=
t2
.
element_size
()
*
t2
.
numel
()
expected_size_2
=
t2
.
element_size
()
*
t2
.
numel
()
cache
.
set
(
"key1"
,
t1
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
t1
)
)
assert
cache
.
stats
[
"current_bytes"
]
==
expected_size_1
assert
cache
.
stats
[
"current_bytes"
]
==
expected_size_1
cache
.
set
(
"key2"
,
t2
)
cache
.
set
(
"key2"
,
CachedEmbedding
(
t2
)
)
assert
cache
.
stats
[
"current_bytes"
]
==
expected_size_1
+
expected_size_2
assert
cache
.
stats
[
"current_bytes"
]
==
expected_size_1
+
expected_size_2
def
test_size_updated_on_overwrite
(
self
):
def
test_size_updated_on_overwrite
(
self
):
...
@@ -132,10 +146,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
...
@@ -132,10 +146,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
small_tensor
=
torch
.
randn
(
10
,
10
)
# 400 bytes
small_tensor
=
torch
.
randn
(
10
,
10
)
# 400 bytes
large_tensor
=
torch
.
randn
(
20
,
20
)
# 1600 bytes
large_tensor
=
torch
.
randn
(
20
,
20
)
# 1600 bytes
cache
.
set
(
"key1"
,
small_tensor
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
small_tensor
)
)
initial_size
=
cache
.
stats
[
"current_bytes"
]
initial_size
=
cache
.
stats
[
"current_bytes"
]
cache
.
set
(
"key1"
,
large_tensor
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
large_tensor
)
)
expected_size
=
large_tensor
.
element_size
()
*
large_tensor
.
numel
()
expected_size
=
large_tensor
.
element_size
()
*
large_tensor
.
numel
()
assert
cache
.
stats
[
"current_bytes"
]
==
expected_size
assert
cache
.
stats
[
"current_bytes"
]
==
expected_size
...
@@ -150,7 +164,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
...
@@ -150,7 +164,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
1024
*
1024
)
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
1024
*
1024
)
tensor
=
torch
.
randn
(
10
,
10
)
tensor
=
torch
.
randn
(
10
,
10
)
cache
.
set
(
"key1"
,
tensor
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
)
)
# Misses
# Misses
cache
.
get
(
"nonexistent1"
)
cache
.
get
(
"nonexistent1"
)
...
@@ -170,7 +184,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
...
@@ -170,7 +184,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
"""Test stats dictionary contains expected keys."""
"""Test stats dictionary contains expected keys."""
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
1024
*
1024
)
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
1024
*
1024
)
tensor
=
torch
.
randn
(
10
,
10
)
tensor
=
torch
.
randn
(
10
,
10
)
cache
.
set
(
"key1"
,
tensor
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
)
)
stats
=
cache
.
stats
stats
=
cache
.
stats
...
@@ -190,10 +204,9 @@ class TestMultimodalEmbeddingCacheManagerStats:
...
@@ -190,10 +204,9 @@ class TestMultimodalEmbeddingCacheManagerStats:
capacity
=
1000
capacity
=
1000
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
capacity
)
cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
=
capacity
)
# Create tensor of known size
# float32 = 4 bytes, so 25 elements = 100 bytes
# float32 = 4 bytes, so 25 elements = 100 bytes
tensor
=
torch
.
zeros
(
25
,
dtype
=
torch
.
float32
)
tensor
=
torch
.
zeros
(
25
,
dtype
=
torch
.
float32
)
cache
.
set
(
"key1"
,
tensor
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
)
)
stats
=
cache
.
stats
stats
=
cache
.
stats
expected_utilization
=
100
/
capacity
expected_utilization
=
100
/
capacity
...
@@ -209,7 +222,7 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
...
@@ -209,7 +222,7 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
tensor
=
torch
.
randn
(
10
,
10
)
tensor
=
torch
.
randn
(
10
,
10
)
assert
tensor
.
is_contiguous
()
assert
tensor
.
is_contiguous
()
result
=
cache
.
set
(
"key1"
,
tensor
)
result
=
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
)
)
assert
result
is
True
assert
result
is
True
def
test_set_non_contiguous_tensor_raises
(
self
):
def
test_set_non_contiguous_tensor_raises
(
self
):
...
@@ -221,4 +234,29 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
...
@@ -221,4 +234,29 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
assert
not
tensor
.
is_contiguous
()
assert
not
tensor
.
is_contiguous
()
with
pytest
.
raises
(
AssertionError
,
match
=
"Tensor must be contiguous"
):
with
pytest
.
raises
(
AssertionError
,
match
=
"Tensor must be contiguous"
):
cache
.
set
(
"key1"
,
tensor
)
cache
.
set
(
"key1"
,
CachedEmbedding
(
tensor
))
class
TestCachedEmbeddingNamedTuple
:
"""Tests for CachedEmbedding NamedTuple."""
def
test_fields
(
self
):
tensor
=
torch
.
randn
(
4
,
4
)
grid
=
[[
1
,
2
,
3
]]
entry
=
CachedEmbedding
(
tensor
=
tensor
,
image_grid_thw
=
grid
)
assert
torch
.
equal
(
entry
.
tensor
,
tensor
)
assert
entry
.
image_grid_thw
==
grid
def
test_none_grid
(
self
):
tensor
=
torch
.
randn
(
4
,
4
)
entry
=
CachedEmbedding
(
tensor
=
tensor
,
image_grid_thw
=
None
)
assert
entry
.
image_grid_thw
is
None
def
test_unpacking
(
self
):
tensor
=
torch
.
randn
(
4
,
4
)
grid
=
[[
1
,
2
,
3
]]
entry
=
CachedEmbedding
(
tensor
=
tensor
,
image_grid_thw
=
grid
)
t
,
g
=
entry
assert
torch
.
equal
(
t
,
tensor
)
assert
g
==
grid
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
View file @
c82fe888
...
@@ -15,6 +15,7 @@ import torch
...
@@ -15,6 +15,7 @@ import torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
CachedEmbedding
,
MultimodalEmbeddingCacheManager
,
MultimodalEmbeddingCacheManager
,
)
)
from
dynamo.trtllm.multimodal.cuda_ipc
import
extract_embeddings_from_handles
from
dynamo.trtllm.multimodal.cuda_ipc
import
extract_embeddings_from_handles
...
@@ -148,7 +149,7 @@ async def _fetch_embeddings_with_cache(
...
@@ -148,7 +149,7 @@ async def _fetch_embeddings_with_cache(
cached
=
cache
.
get
(
url_hash
)
cached
=
cache
.
get
(
url_hash
)
if
cached
is
not
None
:
if
cached
is
not
None
:
logger
.
info
(
f
"fetch_embeddings_with_cache: cache hit for URL:
{
url
}
"
)
logger
.
info
(
f
"fetch_embeddings_with_cache: cache hit for URL:
{
url
}
"
)
embeddings_with_index
.
append
((
i
,
cached
))
embeddings_with_index
.
append
((
i
,
cached
.
tensor
))
else
:
else
:
logger
.
info
(
f
"fetch_embeddings_with_cache: cache miss for URL:
{
url
}
"
)
logger
.
info
(
f
"fetch_embeddings_with_cache: cache miss for URL:
{
url
}
"
)
uncached_urls
.
append
(
url
)
uncached_urls
.
append
(
url
)
...
@@ -189,7 +190,7 @@ async def _fetch_embeddings_with_cache(
...
@@ -189,7 +190,7 @@ async def _fetch_embeddings_with_cache(
# Cache new tensors (reuse hashes computed during cache lookup)
# Cache new tensors (reuse hashes computed during cache lookup)
for
url
,
url_hash
,
tensor
in
zip
(
uncached_urls
,
uncached_hashes
,
new_tensors
):
for
url
,
url_hash
,
tensor
in
zip
(
uncached_urls
,
uncached_hashes
,
new_tensors
):
cache
.
set
(
url_hash
,
tensor
)
cache
.
set
(
url_hash
,
CachedEmbedding
(
tensor
=
tensor
)
)
logger
.
info
(
logger
.
info
(
f
"fetch_embeddings_with_cache: cached embedding for URL:
{
url
}
, shape:
{
tensor
.
shape
}
"
f
"fetch_embeddings_with_cache: cached embedding for URL:
{
url
}
, shape:
{
tensor
.
shape
}
"
)
)
...
...
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
View file @
c82fe888
...
@@ -18,6 +18,7 @@ if not torch.cuda.is_available():
...
@@ -18,6 +18,7 @@ if not torch.cuda.is_available():
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
CachedEmbedding
,
MultimodalEmbeddingCacheManager
,
MultimodalEmbeddingCacheManager
,
)
)
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
...
@@ -76,7 +77,10 @@ class TestFetchEmbeddingsFromEncoder:
...
@@ -76,7 +77,10 @@ class TestFetchEmbeddingsFromEncoder:
url1
,
url2
=
"http://example.com/img1.jpg"
,
"http://example.com/img2.jpg"
url1
,
url2
=
"http://example.com/img1.jpg"
,
"http://example.com/img2.jpg"
embedding1
,
embedding2
=
torch
.
ones
(
10
,
256
),
torch
.
ones
(
10
,
256
)
*
2
embedding1
,
embedding2
=
torch
.
ones
(
10
,
256
),
torch
.
ones
(
10
,
256
)
*
2
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url1
.
encode
()),
embedding1
)
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url1
.
encode
()),
CachedEmbedding
(
tensor
=
embedding1
),
)
request
:
dict
[
str
,
Any
]
=
{
"messages"
:
[]}
request
:
dict
[
str
,
Any
]
=
{
"messages"
:
[]}
mock_client
=
create_mock_encode_client
([
embedding2
])
mock_client
=
create_mock_encode_client
([
embedding2
])
...
@@ -98,8 +102,14 @@ class TestFetchEmbeddingsFromEncoder:
...
@@ -98,8 +102,14 @@ class TestFetchEmbeddingsFromEncoder:
url1
,
url2
=
"http://example.com/img1.jpg"
,
"http://example.com/img2.jpg"
url1
,
url2
=
"http://example.com/img1.jpg"
,
"http://example.com/img2.jpg"
embedding1
,
embedding2
=
torch
.
ones
(
10
,
256
),
torch
.
ones
(
10
,
256
)
*
2
embedding1
,
embedding2
=
torch
.
ones
(
10
,
256
),
torch
.
ones
(
10
,
256
)
*
2
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url1
.
encode
()),
embedding1
)
encoder_cache
.
set
(
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url2
.
encode
()),
embedding2
)
MultimodalHasher
.
hash_bytes
(
url1
.
encode
()),
CachedEmbedding
(
tensor
=
embedding1
),
)
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url2
.
encode
()),
CachedEmbedding
(
tensor
=
embedding2
),
)
async
def
should_not_call
(
req
:
dict
[
str
,
Any
])
->
None
:
async
def
should_not_call
(
req
:
dict
[
str
,
Any
])
->
None
:
raise
AssertionError
(
"Should not be called"
)
raise
AssertionError
(
"Should not be called"
)
...
...
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
View file @
c82fe888
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
copy
import
copy
import
logging
import
logging
import
os
import
os
...
@@ -26,21 +25,16 @@ from dynamo.runtime import Client, Component, DistributedRuntime
...
@@ -26,21 +25,16 @@ from dynamo.runtime import Client, Component, DistributedRuntime
from
..args
import
Config
from
..args
import
Config
from
..handlers
import
BaseWorkerHandler
,
build_sampling_params
from
..handlers
import
BaseWorkerHandler
,
build_sampling_params
from
..multimodal_utils
import
(
from
..multimodal_utils
import
(
MultiModalGroup
,
MyRequestOutput
,
MyRequestOutput
,
PatchedTokensPrompt
,
PatchedTokensPrompt
,
vLLMMultimodalRequest
,
vLLMMultimodalRequest
,
)
)
from
..multimodal_utils.model
import
is_qwen_vl_model
from
..multimodal_utils.model
import
is_qwen_vl_model
from
..multimodal_utils.prefill_worker_utils
import
(
from
..multimodal_utils.prefill_worker_utils
import
load_multimodal_embeddings
IMAGE_URL_KEY
,
accumulate_embeddings
,
fetch_embeddings_from_encode_workers
,
load_embeddings
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
IMAGE_URL_KEY
=
"image_url"
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
...
@@ -96,8 +90,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -96,8 +90,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
else
:
else
:
self
.
EMBEDDINGS_DTYPE
=
torch
.
float16
self
.
EMBEDDINGS_DTYPE
=
torch
.
float16
self
.
EMBEDDINGS_DEVICE
=
"cpu"
# Create and initialize a dynamo connector for this worker.
# Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently.
# We'll need this to move data between this worker and remote workers efficiently.
# Note: This is synchronous initialization, async initialization happens in async_init
# Note: This is synchronous initialization, async initialization happens in async_init
...
@@ -120,19 +112,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -120,19 +112,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self
.
_connector
=
connect
.
Connector
()
self
.
_connector
=
connect
.
Connector
()
logger
.
info
(
"Multimodal PD Worker async initialization completed."
)
logger
.
info
(
"Multimodal PD Worker async initialization completed."
)
async
def
_build_request_from
_frontend
(
def
_parse
_frontend
_request
(
self
,
raw_request
:
dict
self
,
raw_request
:
dict
)
->
vLLMMultimodalRequest
:
)
->
tuple
[
vLLMMultimodalRequest
,
list
[
str
]]
:
"""
Convert a raw frontend dict into a vLLMMultimodalRequest.
"""
Parse a raw frontend dict into a vLLMMultimodalRequest and image URLs.
When the PD worker is the direct frontend endpoint (no separate
The Rust frontend sends a dict with ``token_ids`` and
processor), the Rust frontend sends a dict representation of PreprocessedRequest.
``multi_modal_data`` (containing image URLs). This method extracts
This method extracts image URLs, routes them to encode workers if available,
those fields into a structured request. No I/O is performed here;
and assembles the standard request object that the rest of ``generate()`` expects.
embedding fetching is handled separately by ``_load_multimodal_data``.
"""
"""
request_id
=
str
(
uuid
.
uuid4
().
hex
)
request_id
=
str
(
uuid
.
uuid4
().
hex
)
# Extract image URLs from the raw frontend dict
image_urls
:
list
[
str
]
=
[]
image_urls
:
list
[
str
]
=
[]
mm_data
=
raw_request
.
get
(
"multi_modal_data"
)
mm_data
=
raw_request
.
get
(
"multi_modal_data"
)
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
...
@@ -140,88 +131,43 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -140,88 +131,43 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if
isinstance
(
item
,
dict
)
and
"Url"
in
item
:
if
isinstance
(
item
,
dict
)
and
"Url"
in
item
:
image_urls
.
append
(
item
[
"Url"
])
image_urls
.
append
(
item
[
"Url"
])
multimodal_groups
:
list
[
MultiModalGroup
]
=
[]
if
self
.
encode_worker_client
and
image_urls
:
multimodal_groups
=
await
fetch_embeddings_from_encode_workers
(
self
.
encode_worker_client
,
image_urls
,
request_id
,
)
sampling_params
=
build_sampling_params
(
sampling_params
=
build_sampling_params
(
raw_request
,
self
.
default_sampling_params
raw_request
,
self
.
default_sampling_params
)
)
re
turn
vLLMMultimodalRequest
(
re
quest
=
vLLMMultimodalRequest
(
engine_prompt
=
PatchedTokensPrompt
(
engine_prompt
=
PatchedTokensPrompt
(
prompt_token_ids
=
raw_request
[
"token_ids"
]
prompt_token_ids
=
raw_request
[
"token_ids"
]
),
),
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
request_id
=
request_id
,
model
=
raw_request
.
get
(
"model"
),
model
=
raw_request
.
get
(
"model"
),
multimodal_inputs
=
multimodal_groups
,
)
)
# ── Request parsing ────────────────────────────────────────────────
return
request
,
image_urls
async
def
_parse_request
(
self
,
request
)
->
vLLMMultimodalRequest
:
"""Normalize any incoming format into a validated vLLMMultimodalRequest.
Handles three input shapes:
1. Raw frontend dict (has ``token_ids`` + ``multi_modal_data``)
2. JSON string (from encode worker or other serializers)
3. Plain dict (Pydantic-compatible mapping)
"""
if
isinstance
(
request
,
dict
)
and
"token_ids"
in
request
:
return
await
self
.
_build_request_from_frontend
(
request
)
if
type
(
request
)
is
vLLMMultimodalRequest
:
return
request
if
type
(
request
)
is
str
:
return
vLLMMultimodalRequest
.
model_validate_json
(
request
)
return
vLLMMultimodalRequest
.
model_validate
(
request
)
# ── Multimodal data loading ──────────────────────────────────────
# ── Multimodal data loading ──────────────────────────────────────
async
def
_load_multimodal_data
(
async
def
_load_multimodal_data
(
self
,
request
:
vLLMMultimodalRequest
self
,
image_urls
:
list
[
str
],
request_id
:
str
)
->
tuple
[
dict
[
str
,
Any
],
list
[
int
]]:
)
->
dict
[
str
,
Any
]:
"""Load pre-computed embeddings into an engine-ready dict.
"""Fetch embeddings from encode workers and load into an engine-ready dict.
Each ``MultiModalGroup`` carries embeddings from encode workers,
loaded via NIXL RDMA or local safetensors.
No-op when --route-to-encoder is not set.
Returns an empty dict when no encode worker is configured or no images
are present.
"""
"""
multimodal_inputs
:
list
[
MultiModalGroup
]
=
request
.
multimodal_inputs
or
[]
if
not
self
.
encode_worker_client
or
not
image_urls
:
multi_modal_data
:
dict
[
str
,
Any
]
=
defaultdict
(
list
)
return
defaultdict
(
list
)
task_lists
=
[
return
await
load_multimodal_embeddings
(
asyncio
.
create_task
(
self
.
encode_worker_client
,
# type: ignore[arg-type]
load_embeddings
(
image_urls
,
mi
,
request_id
,
self
.
EMBEDDINGS_DTYPE
,
self
.
embedding_receiver
,
self
.
EMBEDDINGS_DEVICE
,
model
=
self
.
config
.
model
,
self
.
embedding_receiver
,
embeddings_dtype
=
self
.
EMBEDDINGS_DTYPE
,
)
cache
=
self
.
embedding_cache_manager
,
)
)
for
mi
in
multimodal_inputs
]
receiver_tensor_ids
:
list
[
int
]
=
[]
for
task
,
mi
in
zip
(
task_lists
,
multimodal_inputs
):
tensor_id
,
embeddings
=
await
task
receiver_tensor_ids
.
append
(
tensor_id
)
accumulate_embeddings
(
multi_modal_data
,
self
.
config
.
model
,
self
.
EMBEDDINGS_DTYPE
,
embeddings
,
mi
.
image_grid_thw
,
)
return
multi_modal_data
,
receiver_tensor_ids
# ── Request metadata finalization ────────────────────────────────
# ── Request metadata finalization ────────────────────────────────
...
@@ -230,14 +176,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -230,14 +176,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request
:
vLLMMultimodalRequest
,
request
:
vLLMMultimodalRequest
,
multi_modal_data
:
dict
[
str
,
Any
],
multi_modal_data
:
dict
[
str
,
Any
],
)
->
None
:
)
->
None
:
"""Attach model-specific metadata
and strip heavy fields from request
.
"""Attach model-specific metadata
to the request for the decode worker
.
For Qwen VL (mRoPE) models, captures image grid dimensions and
For Qwen VL (mRoPE) models, captures image grid dimensions and
embedding shapes so the decode worker can reconstruct
embedding shapes so the decode worker can reconstruct
``multi_modal_data`` consistently for multiple images.
``multi_modal_data`` consistently for multiple images.
Also clears ``multimodal_inputs`` — the raw embeddings / URLs are no
longer needed once ``multi_modal_data`` is built.
"""
"""
if
is_qwen_vl_model
(
self
.
config
.
model
)
and
isinstance
(
if
is_qwen_vl_model
(
self
.
config
.
model
)
and
isinstance
(
multi_modal_data
.
get
(
"image"
),
dict
multi_modal_data
.
get
(
"image"
),
dict
...
@@ -254,11 +197,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -254,11 +197,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
request
.
embeddings_shape
=
list
(
image_embeds
.
shape
)
request
.
embeddings_shape
=
list
(
image_embeds
.
shape
)
# Use empty list instead of None to satisfy Pydantic validation
logger
.
debug
(
f
"Prepared multimodal data size:
{
len
(
multi_modal_data
[
'image'
])
}
"
)
# on decode worker after vllm upgrade.
request
.
multimodal_inputs
=
[]
logger
.
info
(
f
"Prepared multimodal data size:
{
len
(
multi_modal_data
[
'image'
])
}
"
)
logger
.
debug
(
"Multimodal data keys: %s"
,
list
(
multi_modal_data
.
keys
()))
logger
.
debug
(
"Multimodal data keys: %s"
,
list
(
multi_modal_data
.
keys
()))
# ── Response serialization ───────────────────────────────────────
# ── Response serialization ───────────────────────────────────────
...
@@ -318,7 +257,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -318,7 +257,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self
,
self
,
request
:
vLLMMultimodalRequest
,
request
:
vLLMMultimodalRequest
,
multi_modal_data
:
dict
[
str
,
Any
],
multi_modal_data
:
dict
[
str
,
Any
],
received_tensor_ids
:
list
[
int
],
):
):
"""Run prefill and decode on this worker (aggregated mode)."""
"""Run prefill and decode on this worker (aggregated mode)."""
lora_request
=
self
.
_resolve_lora_request
(
request
.
model
)
lora_request
=
self
.
_resolve_lora_request
(
request
.
model
)
...
@@ -332,9 +270,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -332,9 +270,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
for
tensor_id
in
received_tensor_ids
:
self
.
embedding_receiver
.
release_tensor
(
tensor_id
)
num_output_tokens_so_far
=
0
num_output_tokens_so_far
=
0
async
for
response
in
gen
:
async
for
response
in
gen
:
logger
.
debug
(
f
"Response kv_transfer_params:
{
response
.
kv_transfer_params
}
"
)
logger
.
debug
(
f
"Response kv_transfer_params:
{
response
.
kv_transfer_params
}
"
)
...
@@ -351,7 +286,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -351,7 +286,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self
,
self
,
request
:
vLLMMultimodalRequest
,
request
:
vLLMMultimodalRequest
,
multi_modal_data
:
dict
[
str
,
Any
],
multi_modal_data
:
dict
[
str
,
Any
],
received_tensor_ids
:
list
[
int
],
):
):
"""Prefill locally, then forward to a remote decode worker."""
"""Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request
# Prepare prefill-only request
...
@@ -374,9 +308,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -374,9 +308,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
for
tensor_id
in
received_tensor_ids
:
self
.
embedding_receiver
.
release_tensor
(
tensor_id
)
# Drain prefill generator (max_tokens=1, expect a single response)
# Drain prefill generator (max_tokens=1, expect a single response)
async
for
prefill_response
in
gen
:
async
for
prefill_response
in
gen
:
pass
pass
...
@@ -415,6 +346,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -415,6 +346,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
f
"Forwarding disaggregated decode with LoRA '
{
request
.
model
}
' "
f
"Forwarding disaggregated decode with LoRA '
{
request
.
model
}
' "
f
"— ensure the same adapter is loaded on the decode worker."
f
"— ensure the same adapter is loaded on the decode worker."
)
)
async
for
(
async
for
(
decode_response
decode_response
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore[union-attr]
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore[union-attr]
...
@@ -425,30 +357,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -425,30 +357,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# ── Public entry point ───────────────────────────────────────────
# ── Public entry point ───────────────────────────────────────────
async
def
generate
(
self
,
request
,
context
):
async
def
generate
(
self
,
raw_
request
:
dict
,
context
):
"""Parse the request, load multimodal data, and run inference."""
"""Parse the request, load multimodal data, and run inference."""
logger
.
debug
(
f
"Got raw request:
{
request
}
"
)
request
,
image_urls
=
self
.
_parse_frontend_request
(
raw_request
)
request
=
await
self
.
_parse_request
(
request
)
logger
.
debug
(
f
"Received PD request: {{ id:
{
request
.
request_id
}
}}."
)
logger
.
debug
(
f
"Received PD request: {{ id:
{
request
.
request_id
}
}}."
)
multi_modal_data
,
received_tensor_ids
=
await
self
.
_load_multimodal_data
(
multi_modal_data
=
await
self
.
_load_multimodal_data
(
request
image_urls
,
request
.
request_id
)
)
self
.
_finalize_request_metadata
(
request
,
multi_modal_data
)
self
.
_finalize_request_metadata
(
request
,
multi_modal_data
)
logger
.
info
(
f
"Prepared multimodal data size:
{
len
(
multi_modal_data
.
get
(
'image'
,
[]))
}
"
)
logger
.
debug
(
f
"
{
multi_modal_data
}
"
)
if
self
.
enable_disagg
and
self
.
decode_worker_client
:
if
self
.
enable_disagg
and
self
.
decode_worker_client
:
async
for
chunk
in
self
.
_generate_disagg
(
async
for
chunk
in
self
.
_generate_disagg
(
request
,
multi_modal_data
):
request
,
multi_modal_data
,
received_tensor_ids
):
yield
chunk
yield
chunk
else
:
else
:
async
for
chunk
in
self
.
_generate_agg
(
async
for
chunk
in
self
.
_generate_agg
(
request
,
multi_modal_data
):
request
,
multi_modal_data
,
received_tensor_ids
):
yield
chunk
yield
chunk
components/src/dynamo/vllm/multimodal_utils/__init__.py
View file @
c82fe888
...
@@ -19,11 +19,7 @@ from dynamo.vllm.multimodal_utils.model import (
...
@@ -19,11 +19,7 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data
,
construct_mm_data
,
load_vision_model
,
load_vision_model
,
)
)
from
dynamo.vllm.multimodal_utils.prefill_worker_utils
import
(
from
dynamo.vllm.multimodal_utils.prefill_worker_utils
import
load_multimodal_embeddings
accumulate_embeddings
,
fetch_embeddings_from_encode_workers
,
load_embeddings
,
)
from
dynamo.vllm.multimodal_utils.protocol
import
(
from
dynamo.vllm.multimodal_utils.protocol
import
(
MultiModalGroup
,
MultiModalGroup
,
MultiModalInput
,
MultiModalInput
,
...
@@ -52,7 +48,5 @@ __all__ = [
...
@@ -52,7 +48,5 @@ __all__ = [
"MultiModalRequest"
,
"MultiModalRequest"
,
"MyRequestOutput"
,
"MyRequestOutput"
,
"vLLMMultimodalRequest"
,
"vLLMMultimodalRequest"
,
"accumulate_embeddings"
,
"load_multimodal_embeddings"
,
"fetch_embeddings_from_encode_workers"
,
"load_embeddings"
,
]
]
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
View file @
c82fe888
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
import
logging
import
os
import
os
from
collections
import
defaultdict
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
import
torch
import
torch
from
vllm.sampling_params
import
SamplingParams
as
VllmSamplingParams
from
vllm.sampling_params
import
SamplingParams
as
VllmSamplingParams
from
dynamo.common.multimodal.embedding_transfer
import
AbstractEmbeddingReceiver
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
CachedEmbedding
,
MultimodalEmbeddingCacheManager
,
)
from
dynamo.common.multimodal.embedding_transfer
import
(
AbstractEmbeddingReceiver
,
LocalEmbeddingReceiver
,
)
from
dynamo.runtime
import
Client
from
dynamo.runtime
import
Client
from
.encode_utils
import
get_embedding_hash
from
.model
import
construct_mm_data
from
.model
import
construct_mm_data
from
.protocol
import
(
from
.protocol
import
(
MultiModalGroup
,
MultiModalGroup
,
...
@@ -21,39 +31,37 @@ from .protocol import (
...
@@ -21,39 +31,37 @@ from .protocol import (
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
IMAGE_URL_KEY
=
"image_url"
SPLIT_ENCODE
=
int
(
os
.
getenv
(
"DYN_SPLIT_ENCODE"
,
1
))
VIDEO_URL_KEY
=
"video_url"
# Whether to split the multimodal items into smaller batches for encoding. This can help if encoding can be sped up
# by encoding the items separately on multiple workers.
# Experiment with this setting to see whether it brings benefits when concurrency > encoder count.
SPLIT_ENCODE
=
int
(
os
.
getenv
(
"SPLIT_ENCODE"
,
1
))
# ── Internal helpers (all underscore-prefixed) ───────────────────────
async
def
load_embeddings
(
mi
:
MultiModalGroup
,
class
_PendingRelease
:
_embeddings_dtype
:
torch
.
dtype
,
"""Tracks NIXL tensor buffers that should be released after consumption.
_embeddings_device
:
str
,
receiver
:
AbstractEmbeddingReceiver
,
For NIXL receivers, embeddings are views into pre-allocated reusable
)
->
tuple
[
int
,
torch
.
Tensor
]:
buffers. Instead of cloning each embedding eagerly, we defer the
"""Load pre-computed embedding tensor via local safetensors or NIXL RDMA.
release until the caller has consumed the tensors (e.g. via
``_accumulate_embeddings`` which copies data through ``torch.cat``).
Args:
mi: A single MultiModalGroup whose ``serialized_request`` field
contains either a local file path or NIXL RDMA metadata.
embeddings_dtype: Torch dtype for the tensor (used for RDMA path).
embeddings_device: Device string for the tensor (used for RDMA path).
receiver: AbstractEmbeddingReceiver for tensor reads.
Returns:
A tuple of (tensor_id, embeddings), where tensor_id is an integer identifier for the loaded tensor (used for later release),
and the embeddings tensor loaded into CPU memory.
"""
"""
tensor_id
,
embeddings
=
await
receiver
.
receive_embeddings
(
mi
.
serialized_request
)
return
tensor_id
,
embeddings
__slots__
=
(
"_receiver"
,
"_tensor_ids"
)
def
__init__
(
self
,
receiver
:
AbstractEmbeddingReceiver
):
self
.
_receiver
=
receiver
self
.
_tensor_ids
:
List
[
int
]
=
[]
def
track
(
self
,
tensor_id
:
int
)
->
None
:
self
.
_tensor_ids
.
append
(
tensor_id
)
def
accumulate_embeddings
(
def
release_all
(
self
)
->
None
:
for
tid
in
self
.
_tensor_ids
:
self
.
_receiver
.
release_tensor
(
tid
)
self
.
_tensor_ids
.
clear
()
def
_accumulate_embeddings
(
multi_modal_data
:
Dict
[
str
,
Any
],
multi_modal_data
:
Dict
[
str
,
Any
],
model
:
str
,
model
:
str
,
embeddings_dtype
:
torch
.
dtype
,
embeddings_dtype
:
torch
.
dtype
,
...
@@ -113,16 +121,32 @@ def accumulate_embeddings(
...
@@ -113,16 +121,32 @@ def accumulate_embeddings(
)
)
async
def
fetch_embeddings_from_encode_workers
(
def
_ensure_owned_tensors
(
multi_modal_data
:
Dict
[
str
,
Any
])
->
None
:
"""Clone tensor views so NIXL buffers can be safely released.
Only needed for single-image; multi-image goes through torch.cat
which already produces owned tensors.
"""
img
=
multi_modal_data
.
get
(
"image"
)
if
isinstance
(
img
,
dict
):
for
k
,
v
in
img
.
items
():
if
isinstance
(
v
,
torch
.
Tensor
):
img
[
k
]
=
v
.
clone
()
elif
isinstance
(
img
,
torch
.
Tensor
):
multi_modal_data
[
"image"
]
=
img
.
clone
()
async
def
_fetch_from_encode_workers
(
encode_worker_client
:
Client
,
encode_worker_client
:
Client
,
image_urls
:
List
[
str
],
image_urls
:
List
[
str
],
request_id
:
str
,
request_id
:
str
,
)
->
List
[
MultiModalGroup
]:
receiver
:
AbstractEmbeddingReceiver
,
"""Fan out image URLs to encode workers and collect embedding results.
)
->
tuple
[
List
[
MultiModalGroup
],
_PendingRelease
|
None
]:
"""Fan out image URLs to encode workers, load embeddings, and return ready groups.
Splits image URLs into batches based on available encode worker cou
nt
,
For NIXL receivers the returned embeddings are zero-copy views i
nt
o
dispatches via round-robin, and collects the resulting MultiModalGroups
pre-allocated buffers. The returned ``_PendingRelease`` must be
containing pre-computed embeddings
.
released after the tensors have been consumed
.
"""
"""
encode_worker_count
=
len
(
encode_worker_client
.
instance_ids
())
encode_worker_count
=
len
(
encode_worker_client
.
instance_ids
())
if
encode_worker_count
==
0
:
if
encode_worker_count
==
0
:
...
@@ -156,7 +180,6 @@ async def fetch_embeddings_from_encode_workers(
...
@@ -156,7 +180,6 @@ async def fetch_embeddings_from_encode_workers(
)
)
batch
=
[]
batch
=
[]
# Flush remaining
if
batch
:
if
batch
:
encode_request
.
multimodal_inputs
=
batch
encode_request
.
multimodal_inputs
=
batch
payload
=
encode_request
.
model_dump_json
()
payload
=
encode_request
.
model_dump_json
()
...
@@ -164,7 +187,6 @@ async def fetch_embeddings_from_encode_workers(
...
@@ -164,7 +187,6 @@ async def fetch_embeddings_from_encode_workers(
await
encode_worker_client
.
round_robin
(
payload
)
# type: ignore[arg-type]
await
encode_worker_client
.
round_robin
(
payload
)
# type: ignore[arg-type]
)
)
# Collect results
multimodal_groups
:
List
[
MultiModalGroup
]
=
[]
multimodal_groups
:
List
[
MultiModalGroup
]
=
[]
for
stream
in
encode_response_streams
:
for
stream
in
encode_response_streams
:
async
for
response
in
stream
:
async
for
response
in
stream
:
...
@@ -173,4 +195,135 @@ async def fetch_embeddings_from_encode_workers(
...
@@ -173,4 +195,135 @@ async def fetch_embeddings_from_encode_workers(
if
output
.
multimodal_inputs
:
if
output
.
multimodal_inputs
:
multimodal_groups
.
extend
(
output
.
multimodal_inputs
)
multimodal_groups
.
extend
(
output
.
multimodal_inputs
)
return
multimodal_groups
tasks
=
[
asyncio
.
create_task
(
receiver
.
receive_embeddings
(
group
.
serialized_request
))
for
group
in
multimodal_groups
]
loaded
=
await
asyncio
.
gather
(
*
tasks
)
is_local
=
isinstance
(
receiver
,
LocalEmbeddingReceiver
)
pending
:
_PendingRelease
|
None
=
None
if
is_local
else
_PendingRelease
(
receiver
)
for
group
,
(
tensor_id
,
embedding
)
in
zip
(
multimodal_groups
,
loaded
,
strict
=
True
):
group
.
loaded_embedding
=
embedding
if
pending
is
not
None
:
pending
.
track
(
tensor_id
)
return
multimodal_groups
,
pending
async def _fetch_embeddings(
    encode_worker_client: Client,
    image_urls: list[str],
    request_id: str,
    receiver: AbstractEmbeddingReceiver,
    cache: MultimodalEmbeddingCacheManager | None = None,
) -> tuple[list[MultiModalGroup], _PendingRelease | None]:
    """Fetch multimodal embeddings with transparent cache-through.

    Pipeline: check_cache → fetch misses from encode workers → update_cache.
    When *cache* is ``None`` the cache steps are no-ops and all URLs go
    straight to the encode workers.

    Args:
        encode_worker_client: Client used to reach the encode workers.
        image_urls: URLs to resolve; the returned groups follow this order.
        request_id: Identifier used only as a log prefix here and forwarded
            to ``_fetch_from_encode_workers``.
        receiver: Receiver forwarded to ``_fetch_from_encode_workers``.
        cache: Optional embedding cache keyed by ``get_embedding_hash(url)``.

    Returns:
        ``(groups, pending)``. ``groups`` holds one ``MultiModalGroup`` per
        URL. Cache hits carry the cached tensor object itself (no copy);
        misses carry whatever ``_fetch_from_encode_workers`` loaded.
        ``pending`` is the release handle returned by that fetch (``None``
        when nothing was fetched); the caller must release it after
        consuming the tensors.

    Raises:
        ValueError: If the encode workers return a different number of
            groups than URLs requested. Any pending buffers are released
            before the error propagates.
    """
    # Pre-sized so hits and misses can be written back by original index.
    results: list[MultiModalGroup | None] = [None] * len(image_urls)
    to_fetch: list[tuple[int, str, str | None]] = []

    # ── 1. Check cache (no-op when cache is None) ────────────────────
    for idx, url in enumerate(image_urls):
        key: str | None = None
        if cache is not None:
            key = get_embedding_hash(url)
            cached = cache.get(key)
            if cached is not None:
                logger.debug("[%s] Cache hit for URL index %s", request_id, idx)
                results[idx] = MultiModalGroup(
                    loaded_embedding=cached.tensor,
                    image_grid_thw=cached.image_grid_thw,
                )
                continue
        to_fetch.append((idx, url, key))

    # ── 2. Fetch uncached from encode workers ────────────────────────
    pending: _PendingRelease | None = None
    if to_fetch:
        if cache is not None:
            logger.info(
                "[%s] Cache miss for %s/%s URLs, fetching from encode workers",
                request_id,
                len(to_fetch),
                len(image_urls),
            )
        miss_urls = [url for _, url, _ in to_fetch]
        groups, pending = await _fetch_from_encode_workers(
            encode_worker_client,
            miss_urls,
            request_id,
            receiver,
        )

        # ── 3. Update cache (no-op when cache is None) ──────────────
        # From here on we own `pending`; if anything below raises it would
        # never reach the caller, so release it before propagating.
        try:
            # Validate the count up front so a short/long reply cannot
            # leave the cache partially updated.
            if len(groups) != len(to_fetch):
                raise ValueError(
                    f"[{request_id}] Encode workers returned {len(groups)} "
                    f"embedding group(s) for {len(to_fetch)} requested URL(s)"
                )
            for (idx, _url, key), group in zip(to_fetch, groups, strict=True):
                if cache is not None and key is not None:
                    # Clone so the cached tensor does not alias the
                    # receiver's buffer.
                    cache.set(
                        key,
                        CachedEmbedding(
                            tensor=group.loaded_embedding.clone(),
                            image_grid_thw=group.image_grid_thw,
                        ),
                    )
                results[idx] = group
        except Exception:
            if pending is not None:
                pending.release_all()
            raise
    elif image_urls:
        # Nothing to fetch although URLs were given: every one was a hit.
        # (With no URLs at all there is nothing worth reporting.)
        logger.info(
            "[%s] All %s URLs served from cache", request_id, len(image_urls)
        )

    return [r for r in results if r is not None], pending
# ── Public API (single entry point) ─────────────────────────────────
async def load_multimodal_embeddings(
    encode_worker_client: Client,
    image_urls: list[str],
    request_id: str,
    receiver: AbstractEmbeddingReceiver,
    *,
    model: str,
    embeddings_dtype: torch.dtype,
    cache: MultimodalEmbeddingCacheManager | None = None,
) -> Dict[str, Any]:
    """Fetch embeddings and build engine-ready ``multi_modal_data``.

    Full pipeline:
        cache check → remote fetch → cache update → accumulate → release buffers.

    Args:
        encode_worker_client: Client used to reach the encode workers.
        image_urls: URLs whose embeddings should be loaded.
        request_id: Request identifier forwarded to ``_fetch_embeddings``.
        receiver: Embedding receiver forwarded to ``_fetch_embeddings``.
        model: Model name forwarded to ``_accumulate_embeddings``.
        embeddings_dtype: Dtype forwarded to ``_accumulate_embeddings``.
        cache: Optional embedding cache forwarded to ``_fetch_embeddings``.

    Returns:
        A ``defaultdict(list)`` populated by ``_accumulate_embeddings``,
        suitable for passing to ``TokensPrompt(multi_modal_data=...)``.

    Any pending receiver buffers are released before this function returns
    or raises, so a failure while accumulating cannot leak them.
    """
    groups, pending = await _fetch_embeddings(
        encode_worker_client,
        image_urls,
        request_id,
        receiver,
        cache=cache,
    )

    multi_modal_data: Dict[str, Any] = defaultdict(list)
    try:
        for group in groups:
            _accumulate_embeddings(
                multi_modal_data,
                model,
                embeddings_dtype,
                group.loaded_embedding,
                group.image_grid_thw,
            )

        # Multi-image: torch.cat in _accumulate_embeddings already created
        # owned tensors. Single-image: the data is still a view into the
        # receiver buffer, so we must clone before releasing.
        if pending is not None and len(groups) == 1:
            _ensure_owned_tensors(multi_modal_data)
    finally:
        # Release on both the success and the failure path.
        if pending is not None:
            pending.release_all()

    return multi_modal_data
components/src/dynamo/vllm/multimodal_utils/protocol.py
View file @
c82fe888
...
@@ -18,6 +18,7 @@ import json
...
@@ -18,6 +18,7 @@ import json
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
Union
import
msgspec
import
msgspec
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
field_serializer
,
field_validator
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
field_serializer
,
field_validator
from
pydantic_core
import
core_schema
from
pydantic_core
import
core_schema
from
typing_extensions
import
NotRequired
from
typing_extensions
import
NotRequired
...
@@ -171,6 +172,7 @@ class MultiModalGroup(BaseModel):
...
@@ -171,6 +172,7 @@ class MultiModalGroup(BaseModel):
Union
[
Tuple
[
int
,
int
,
int
],
Tuple
[
int
,
int
,
int
,
int
]]
Union
[
Tuple
[
int
,
int
,
int
],
Tuple
[
int
,
int
,
int
,
int
]]
]
=
None
]
=
None
serialized_request
:
Optional
[
TransferRequest
]
=
None
serialized_request
:
Optional
[
TransferRequest
]
=
None
loaded_embedding
:
Optional
[
torch
.
Tensor
]
=
Field
(
default
=
None
,
exclude
=
True
)
class
vLLMMultimodalRequest
(
vLLMGenerateRequest
):
class
vLLMMultimodalRequest
(
vLLMGenerateRequest
):
...
...
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
View file @
c82fe888
...
@@ -4,17 +4,17 @@
...
@@ -4,17 +4,17 @@
"""Unit tests for MultimodalPDWorkerHandler."""
"""Unit tests for MultimodalPDWorkerHandler."""
import
json
import
json
from
collections
import
defaultdict
from
unittest.mock
import
AsyncMock
,
MagicMock
,
patch
from
unittest.mock
import
AsyncMock
,
MagicMock
,
patch
import
pytest
import
pytest
import
torch
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
MultimodalEmbeddingCacheManager
,
MultimodalEmbeddingCacheManager
,
)
)
from
dynamo.vllm.multimodal_handlers
import
multimodal_pd_worker_handler
as
mod
from
dynamo.vllm.multimodal_handlers
import
multimodal_pd_worker_handler
as
mod
from
dynamo.vllm.multimodal_utils.protocol
import
(
from
dynamo.vllm.multimodal_utils.protocol
import
(
MultiModalGroup
,
MultiModalInput
,
MyRequestOutput
,
MyRequestOutput
,
PatchedTokensPrompt
,
PatchedTokensPrompt
,
vLLMMultimodalRequest
,
vLLMMultimodalRequest
,
...
@@ -128,26 +128,98 @@ class TestInit:
...
@@ -128,26 +128,98 @@ class TestInit:
assert
handler
.
embedding_cache_manager
.
_capacity_bytes
==
expected_bytes
assert
handler
.
embedding_cache_manager
.
_capacity_bytes
==
expected_bytes
class
TestBuildRequestFromFrontend
:
class
TestParseFrontendRequest
:
def
test_extracts_token_ids_and_sampling_params
(
self
):
"""Parses token_ids and sampling_params from raw frontend dict."""
handler
=
_make_handler
()
handler
.
default_sampling_params
=
{}
raw
=
_make_raw_frontend_request
()
request
,
image_urls
=
handler
.
_parse_frontend_request
(
raw
)
assert
request
.
engine_prompt
[
"prompt_token_ids"
]
==
[
1
,
2
,
3
]
assert
image_urls
==
[]
def
test_extracts_image_urls
(
self
):
"""Extracts image URLs from multi_modal_data."""
handler
=
_make_handler
()
handler
.
default_sampling_params
=
{}
raw
=
_make_raw_frontend_request
(
image_urls
=
[
"http://a.png"
,
"http://b.png"
])
request
,
image_urls
=
handler
.
_parse_frontend_request
(
raw
)
assert
image_urls
==
[
"http://a.png"
,
"http://b.png"
]
class
TestLoadMultimodalData
:
@
pytest
.
mark
.
asyncio
async
def
test_no_encode_client_returns_empty
(
self
):
"""Without encode client -> returns empty dict."""
handler
=
_make_handler
(
encode_worker_client
=
None
)
mm_data
=
await
handler
.
_load_multimodal_data
([
"http://img.png"
],
"req-1"
)
assert
len
(
mm_data
)
==
0
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_with_encode_worker_calls_fetch
(
self
):
async
def
test_no_images_returns_empty
(
self
):
"""With encode client -> delegates to fetch_embeddings_from_encode_workers."""
"""With encode client but no images -> returns empty dict."""
handler
=
_make_handler
(
encode_worker_client
=
MagicMock
())
mm_data
=
await
handler
.
_load_multimodal_data
([],
"req-1"
)
assert
len
(
mm_data
)
==
0
@
pytest
.
mark
.
asyncio
async
def
test_delegates_to_load_multimodal_embeddings
(
self
):
"""With encode client -> delegates to load_multimodal_embeddings."""
mock_client
=
MagicMock
()
mock_client
=
MagicMock
()
handler
=
_make_handler
(
encode_worker_client
=
mock_client
)
handler
=
_make_handler
(
encode_worker_client
=
mock_client
)
handler
.
default_sampling_params
=
{}
fake_group
=
MultiModalGroup
(
multimodal_input
=
MultiModalInput
())
fake_mm_data
=
defaultdict
(
list
,
{
"image"
:
torch
.
randn
(
1
,
10
)})
with
patch
.
object
(
mod
,
"load_multimodal_embeddings"
,
new_callable
=
AsyncMock
,
return_value
=
fake_mm_data
,
)
as
mock_load
:
result
=
await
handler
.
_load_multimodal_data
([
"http://img.png"
],
"req-1"
)
mock_load
.
assert_awaited_once
()
assert
result
is
fake_mm_data
@
pytest
.
mark
.
asyncio
async
def
test_passes_cache_to_load_multimodal_embeddings
(
self
):
"""With cache enabled -> passes cache manager kwarg."""
mock_client
=
MagicMock
()
config
=
_make_config
(
multimodal_embedding_cache_capacity_gb
=
1.0
)
handler
=
_make_handler
(
config
=
config
,
encode_worker_client
=
mock_client
)
with
patch
.
object
(
with
patch
.
object
(
mod
,
mod
,
"
fetch_embeddings_from_encode_worker
s"
,
"
load_multimodal_embedding
s"
,
new_callable
=
AsyncMock
,
new_callable
=
AsyncMock
,
return_value
=
[
fake_group
],
return_value
=
defaultdict
(
list
),
)
as
mock_fetch
:
)
as
mock_load
:
raw
=
_make_raw_frontend_request
(
image_urls
=
[
"http://img.png"
])
await
handler
.
_load_multimodal_data
([
"http://img.png"
],
"req-1"
)
result
=
await
handler
.
_build_request_from_frontend
(
raw
)
mock_fetch
.
assert_awaited_once
()
mock_load
.
assert_awaited_once
()
assert
result
.
multimodal_inputs
==
[
fake_group
]
assert
mock_load
.
call_args
.
kwargs
[
"cache"
]
is
handler
.
embedding_cache_manager
@
pytest
.
mark
.
asyncio
async
def
test_passes_model_and_dtype
(
self
):
"""Model name and embeddings dtype are forwarded."""
mock_client
=
MagicMock
()
handler
=
_make_handler
(
encode_worker_client
=
mock_client
)
with
patch
.
object
(
mod
,
"load_multimodal_embeddings"
,
new_callable
=
AsyncMock
,
return_value
=
defaultdict
(
list
),
)
as
mock_load
:
await
handler
.
_load_multimodal_data
([
"http://img.png"
],
"req-1"
)
assert
mock_load
.
call_args
.
kwargs
[
"model"
]
==
handler
.
config
.
model
assert
(
mock_load
.
call_args
.
kwargs
[
"embeddings_dtype"
]
==
handler
.
EMBEDDINGS_DTYPE
)
class
TestGenerateAgg
:
class
TestGenerateAgg
:
...
@@ -158,7 +230,6 @@ class TestGenerateAgg:
...
@@ -158,7 +230,6 @@ class TestGenerateAgg:
request
=
_make_vllm_request
()
request
=
_make_vllm_request
()
engine_resp
=
_make_engine_response
()
engine_resp
=
_make_engine_response
()
# Add a proper output so we exercise the happy path
output
=
MagicMock
()
output
=
MagicMock
()
output
.
token_ids
=
[
10
,
11
]
output
.
token_ids
=
[
10
,
11
]
output
.
finish_reason
=
"stop"
output
.
finish_reason
=
"stop"
...
@@ -172,7 +243,7 @@ class TestGenerateAgg:
...
@@ -172,7 +243,7 @@ class TestGenerateAgg:
handler
.
engine_client
.
generate
=
fake_generate
handler
.
engine_client
.
generate
=
fake_generate
chunks
=
[]
chunks
=
[]
async
for
chunk
in
handler
.
_generate_agg
(
request
,
{
"image"
:
[]}
,
[]
):
async
for
chunk
in
handler
.
_generate_agg
(
request
,
{
"image"
:
[]}):
chunks
.
append
(
chunk
)
chunks
.
append
(
chunk
)
assert
len
(
chunks
)
==
1
assert
len
(
chunks
)
==
1
...
@@ -189,7 +260,6 @@ class TestGenerateDisagg:
...
@@ -189,7 +260,6 @@ class TestGenerateDisagg:
handler
=
_make_handler
(
config
=
config
,
decode_worker_client
=
decode_client
)
handler
=
_make_handler
(
config
=
config
,
decode_worker_client
=
decode_client
)
handler
.
engine_client
=
MagicMock
()
handler
.
engine_client
=
MagicMock
()
# Mock prefill engine response
prefill_resp
=
_make_engine_response
()
prefill_resp
=
_make_engine_response
()
prefill_resp
.
kv_transfer_params
=
{
"block_ids"
:
[
0
,
1
]}
prefill_resp
.
kv_transfer_params
=
{
"block_ids"
:
[
0
,
1
]}
...
@@ -198,7 +268,6 @@ class TestGenerateDisagg:
...
@@ -198,7 +268,6 @@ class TestGenerateDisagg:
handler
.
engine_client
.
generate
=
fake_generate
handler
.
engine_client
.
generate
=
fake_generate
# Mock decode worker response
decode_output
=
MyRequestOutput
(
decode_output
=
MyRequestOutput
(
request_id
=
"req-1"
,
request_id
=
"req-1"
,
prompt
=
"test"
,
prompt
=
"test"
,
...
@@ -220,7 +289,7 @@ class TestGenerateDisagg:
...
@@ -220,7 +289,7 @@ class TestGenerateDisagg:
request
=
_make_vllm_request
()
request
=
_make_vllm_request
()
chunks
=
[]
chunks
=
[]
async
for
chunk
in
handler
.
_generate_disagg
(
request
,
{
"image"
:
[]}
,
[]
):
async
for
chunk
in
handler
.
_generate_disagg
(
request
,
{
"image"
:
[]}):
chunks
.
append
(
chunk
)
chunks
.
append
(
chunk
)
assert
len
(
chunks
)
==
1
assert
len
(
chunks
)
==
1
...
...
components/src/dynamo/vllm/tests/multimodal_utils/test_vllm_prefill_worker_utils.py
0 → 100644
View file @
c82fe888
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for load_multimodal_embeddings in prefill_worker_utils."""
from
unittest.mock
import
AsyncMock
,
patch
import
pytest
import
torch
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
CachedEmbedding
,
MultimodalEmbeddingCacheManager
,
)
from
dynamo.vllm.multimodal_utils
import
prefill_worker_utils
as
mod
from
dynamo.vllm.multimodal_utils.protocol
import
MultiModalGroup
,
MultiModalInput
# Markers applied to every test in this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.vllm,
    pytest.mark.gpu_0,
    pytest.mark.multimodal,
]

# Model name and embedding dtype passed to load_multimodal_embeddings by the tests below;
# DTYPE is also the dtype of the random tensors they build.
MODEL = "test-model"
DTYPE = torch.float16
async def _load_with_patched_fetch(urls, cache, **mock_kwargs):
    """Run ``load_multimodal_embeddings`` with the encode-worker fetch mocked.

    ``mock_kwargs`` are forwarded to the ``AsyncMock`` that replaces
    ``_fetch_from_encode_workers`` (e.g. ``return_value=...``).
    Returns ``(mm_data, fetch_mock)``.
    """
    with patch.object(
        mod,
        "_fetch_from_encode_workers",
        new_callable=AsyncMock,
        **mock_kwargs,
    ) as fetch_mock:
        loaded = await mod.load_multimodal_embeddings(
            AsyncMock(),
            urls,
            "req-1",
            None,
            model=MODEL,
            embeddings_dtype=DTYPE,
            cache=cache,
        )
    return loaded, fetch_mock


class TestLoadMultimodalEmbeddings:
    """Cache-hit, cache-miss, no-cache and mixed scenarios for the public loader."""

    @pytest.mark.asyncio
    async def test_all_cached(self):
        """All URLs cached -> no encode worker call, returns accumulated mm_data."""
        store = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
        embedding = torch.randn(1, 10, dtype=DTYPE)
        image_url = "http://img1.png"
        store.set(
            mod.get_embedding_hash(image_url),
            CachedEmbedding(tensor=embedding, image_grid_thw=[[1, 2, 3]]),
        )

        mm_data, fetch_mock = await _load_with_patched_fetch([image_url], store)

        fetch_mock.assert_not_awaited()
        assert torch.equal(mm_data["image"], embedding)

    @pytest.mark.asyncio
    async def test_all_uncached_with_cache(self):
        """All URLs uncached with cache -> encode worker call, results cached."""
        store = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
        image_url = "http://img1.png"
        embedding = torch.randn(1, 10, dtype=DTYPE)
        fetched_group = MultiModalGroup(
            multimodal_input=MultiModalInput(),
            image_grid_thw=[[1, 2, 3]],
            loaded_embedding=embedding,
        )

        mm_data, fetch_mock = await _load_with_patched_fetch(
            [image_url], store, return_value=([fetched_group], None)
        )

        fetch_mock.assert_awaited_once()
        assert torch.equal(mm_data["image"], embedding)

        # The fetched embedding must now be retrievable from the cache.
        entry = store.get(mod.get_embedding_hash(image_url))
        assert entry is not None
        assert torch.equal(entry.tensor, embedding)

    @pytest.mark.asyncio
    async def test_no_cache(self):
        """Without cache -> all URLs go to encode workers."""
        image_url = "http://img1.png"
        embedding = torch.randn(1, 10, dtype=DTYPE)
        fetched_group = MultiModalGroup(
            multimodal_input=MultiModalInput(),
            loaded_embedding=embedding,
        )

        mm_data, fetch_mock = await _load_with_patched_fetch(
            [image_url], None, return_value=([fetched_group], None)
        )

        fetch_mock.assert_awaited_once()
        assert torch.equal(mm_data["image"], embedding)

    @pytest.mark.asyncio
    async def test_mixed_cache(self):
        """Mixed cache hits/misses -> only misses sent to encode workers."""
        store = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
        url_cached = "http://cached.png"
        url_miss = "http://miss.png"
        cached_tensor = torch.randn(1, 10, dtype=DTYPE)
        miss_tensor = torch.randn(1, 10, dtype=DTYPE)

        store.set(
            mod.get_embedding_hash(url_cached),
            CachedEmbedding(tensor=cached_tensor, image_grid_thw=None),
        )
        fetched_group = MultiModalGroup(
            multimodal_input=MultiModalInput(),
            image_grid_thw=None,
            loaded_embedding=miss_tensor,
        )

        mm_data, fetch_mock = await _load_with_patched_fetch(
            [url_cached, url_miss], store, return_value=([fetched_group], None)
        )

        fetch_mock.assert_awaited_once()
        # Second positional argument is the list of URLs actually fetched.
        assert fetch_mock.call_args.args[1] == [url_miss]

        expected = torch.cat((cached_tensor, miss_tensor))
        assert torch.equal(mm_data["image"], expected)
examples/backends/vllm/launch/disagg_multimodal_e_pd.sh
0 → 100755
View file @
c82fe888
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e

# Default values
MODEL_NAME="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
SINGLE_GPU=false

# Parse command line arguments
# All extra arguments are passed through to the PD worker's dynamo.vllm
# (which routes them to Dynamo or vLLM as appropriate).
EXTRA_PD_ARGS=()
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            # Fail with a clear message instead of dying silently on `shift 2`
            # (under `set -e`) when the value is missing.
            if [[ $# -lt 2 ]]; then
                echo "Error: --model requires a value" >&2
                exit 1
            fi
            MODEL_NAME=$2
            shift 2
            ;;
        --single-gpu)
            SINGLE_GPU=true
            shift
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS] [EXTRA_ARGS...]"
            echo ""
            echo "Disaggregated multimodal serving with separate Encode and aggregated PD worker"
            echo ""
            echo "Options:"
            echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
            echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates"
            echo " --single-gpu Run encode and PD workers on the same GPU (for small models, e.g. 2B)"
            echo " -h, --help Show this help message"
            echo ""
            echo "All additional arguments are passed through to the PD worker's dynamo.vllm."
            echo "Dynamo args (e.g. --multimodal-embedding-cache-capacity-gb) and"
            echo "vLLM engine args (e.g. --no-enable-prefix-caching) are automatically routed."
            echo ""
            echo "Examples:"
            echo " $0 --model llava-hf/llava-1.5-7b-hf"
            echo " $0 --model microsoft/Phi-3.5-vision-instruct"
            echo " $0 --model Qwen/Qwen2.5-VL-7B-Instruct"
            echo " $0 --no-enable-prefix-caching --multimodal-embedding-cache-capacity-gb 2"
            echo " $0 --model Qwen/Qwen2-VL-2B-Instruct --single-gpu"
            echo ""
            exit 0
            ;;
        *)
            EXTRA_PD_ARGS+=("$1")
            shift
            ;;
    esac
done

# Install the cleanup trap only once argument parsing has succeeded: no child
# process exists before this point, and `kill 0` signals the whole process
# group, which must not happen for a plain --help or a usage error.
trap 'echo Cleaning up...; kill 0' EXIT
PD_MAX_MODEL_LEN="16384"

echo "=================================================="
echo "Disaggregated Multimodal Serving (E + PD)"
echo "=================================================="
echo "Model: $MODEL_NAME"
echo "=================================================="

# Start frontend (no router mode)
echo "Starting frontend..."
python -m dynamo.frontend &

# Extra flags shared by both workers. Kept as an array so it expands to zero
# or more separate words without relying on unquoted word splitting.
EXTRA_ARGS=()

# Embedding transfer: 1 = local file (safetensors), 0 = NIXL RDMA
export TRANSFER_LOCAL=${TRANSFER_LOCAL:-1}

# GPU assignments (override via environment variables)
if [[ "$SINGLE_GPU" == "true" ]]; then
    DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-0}
    DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0}
    DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.4}
    DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.4}
    EXTRA_ARGS=(--enforce-eager)
else
    DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1}
    DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-2}
    DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.9}
    DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.9}
fi

# Start encode worker
echo "Starting encode worker on GPU $DYN_ENCODE_WORKER_GPU (GPU mem: $DYN_ENCODE_GPU_MEM)..."
CUDA_VISIBLE_DEVICES="$DYN_ENCODE_WORKER_GPU" \
python -m dynamo.vllm \
    --multimodal-encode-worker \
    --enable-multimodal \
    --model "$MODEL_NAME" \
    --gpu-memory-utilization "$DYN_ENCODE_GPU_MEM" \
    "${EXTRA_ARGS[@]}" &

# Start PD worker (aggregated prefill+decode, routes to encoder for embeddings)
echo "Starting PD worker on GPU $DYN_PD_WORKER_GPU (GPU mem: $DYN_PD_GPU_MEM)..."
CUDA_VISIBLE_DEVICES="$DYN_PD_WORKER_GPU" \
python -m dynamo.vllm \
    --route-to-encoder \
    --multimodal-worker \
    --enable-multimodal \
    --enable-mm-embeds \
    --model "$MODEL_NAME" \
    --max-model-len "$PD_MAX_MODEL_LEN" \
    --gpu-memory-utilization "$DYN_PD_GPU_MEM" \
    "${EXTRA_ARGS[@]}" \
    "${EXTRA_PD_ARGS[@]}" &

echo "=================================================="
echo "All components started. Waiting for initialization..."
echo "=================================================="

# Wait for all background processes to complete
wait
tests/serve/conftest.py
View file @
c82fe888
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
os
from
io
import
BytesIO
import
pytest
import
pytest
from
pytest_httpserver
import
HTTPServer
from
pytest_httpserver
import
HTTPServer
...
@@ -17,6 +18,28 @@ MULTIMODAL_IMG_PATH = os.path.join(
...
@@ -17,6 +18,28 @@ MULTIMODAL_IMG_PATH = os.path.join(
MULTIMODAL_IMG_URL
=
f
"http://localhost:
{
IMAGE_SERVER_PORT
}
/llm-graphic.png"
MULTIMODAL_IMG_URL
=
f
"http://localhost:
{
IMAGE_SERVER_PORT
}
/llm-graphic.png"
# Git LFS pointer files start with "version "; serve a real PNG when the asset is not pulled.
def get_multimodal_test_image_bytes() -> bytes:
    """Produce the PNG payload served at /llm-graphic.png.

    The contents of MULTIMODAL_IMG_PATH are returned when that file exists
    and does not begin with b"version " (the prefix of a Git LFS pointer
    file). In every other case a 2x2 solid green PNG is generated in memory
    with Pillow and returned instead.
    """
    if os.path.isfile(MULTIMODAL_IMG_PATH):
        with open(MULTIMODAL_IMG_PATH, "rb") as image_file:
            contents = image_file.read()
        if not contents.startswith(b"version "):
            # Real asset is present (e.g. checked out with LFS content pulled).
            return contents

    # Fallback: the asset is missing or is only an LFS pointer.
    # Pillow is imported lazily so this conftest still loads in environments
    # that do not have it installed (e.g. pre-commit).
    from PIL import Image

    # TODO: different models / tests may expect different colors. Need to
    # reconcile the code to support all cases locally if needed.
    placeholder = Image.new("RGB", (2, 2), color="green")
    png_buffer = BytesIO()
    placeholder.save(png_buffer, format="PNG")
    return png_buffer.getvalue()
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
httpserver_listen_address
():
def
httpserver_listen_address
():
return
(
"127.0.0.1"
,
IMAGE_SERVER_PORT
)
return
(
"127.0.0.1"
,
IMAGE_SERVER_PORT
)
...
@@ -33,15 +56,14 @@ def image_server(httpserver: HTTPServer):
...
@@ -33,15 +56,14 @@ def image_server(httpserver: HTTPServer):
Currently serves:
Currently serves:
- /llm-graphic.png - LLM diagram image for multimodal tests
- /llm-graphic.png - LLM diagram image for multimodal tests
(or a minimal PNG if the file is a Git LFS pointer / not pulled)
Usage:
Usage:
def test_multimodal(image_server):
def test_multimodal(image_server):
url = "http://localhost:8765/llm-graphic.png"
url = "http://localhost:8765/llm-graphic.png"
# ... use url in your test payload
# ... use url in your test payload
"""
"""
# Load LLM graphic image from shared test data
image_data
=
get_multimodal_test_image_bytes
()
with
open
(
MULTIMODAL_IMG_PATH
,
"rb"
)
as
f
:
image_data
=
f
.
read
()
# Configure server endpoint
# Configure server endpoint
httpserver
.
expect_request
(
"/llm-graphic.png"
).
respond_with_data
(
httpserver
.
expect_request
(
"/llm-graphic.png"
).
respond_with_data
(
...
...
tests/serve/test_vllm.py
View file @
c82fe888
...
@@ -16,7 +16,7 @@ from tests.serve.common import (
...
@@ -16,7 +16,7 @@ from tests.serve.common import (
params_with_model_mark
,
params_with_model_mark
,
run_serve_deployment
,
run_serve_deployment
,
)
)
from
tests.serve.conftest
import
MULTIMODAL_IMG_
PATH
,
MULTIMODAL_IMG_URL
from
tests.serve.conftest
import
MULTIMODAL_IMG_
URL
,
get_multimodal_test_image_bytes
from
tests.serve.lora_utils
import
MinioLoraConfig
from
tests.serve.lora_utils
import
MinioLoraConfig
from
tests.utils.constants
import
DefaultPort
from
tests.utils.constants
import
DefaultPort
from
tests.utils.engine_process
import
EngineConfig
from
tests.utils.engine_process
import
EngineConfig
...
@@ -276,37 +276,34 @@ vllm_configs = {
...
@@ -276,37 +276,34 @@ vllm_configs = {
completion_payload_default
(),
completion_payload_default
(),
],
],
),
),
# The original script name is misleading: agg_multimodal_epd.sh is actually a disagg
"multimodal_disagg_qwen2vl_2b_e_pd"
:
VLLMConfig
(
# case which uses a disagg encoder. We are bringing this test back shortly
name
=
"multimodal_disagg_qwen2vl_2b_e_pd"
,
# TODO(qiwa): enable this in https://github.com/ai-dynamo/dynamo/pull/6061/
directory
=
vllm_dir
,
# "multimodal_agg_qwen2vl_2b_epd": VLLMConfig(
script_name
=
"disagg_multimodal_e_pd.sh"
,
# name="multimodal_agg_qwen2vl_2b_epd",
marks
=
[
pytest
.
mark
.
gpu_1
,
pytest
.
mark
.
pre_merge
],
# directory=vllm_dir,
model
=
"Qwen/Qwen2-VL-2B-Instruct"
,
# script_name="agg_multimodal_epd.sh",
script_args
=
[
"--model"
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"--single-gpu"
],
# marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
request_payloads
=
[
# model="Qwen/Qwen2-VL-2B-Instruct",
chat_payload
(
# script_args=["--model", "Qwen/Qwen2-VL-2B-Instruct", "--single-gpu"],
[
# request_payloads=[
{
# chat_payload(
"type"
:
"text"
,
# [
"text"
:
"What colors are in the following image? Respond only with the colors."
,
# {
},
# "type": "text",
{
# "text": "What colors are in the following image? Respond only with the colors.",
"type"
:
"image_url"
,
# },
"image_url"
:
{
"url"
:
MULTIMODAL_IMG_URL
},
# {
},
# "type": "image_url",
],
# "image_url": {"url": MULTIMODAL_IMG_URL},
repeat_count
=
1
,
# },
# With proper prompt templating, the model actually only returns "green",
# ],
# verified behavior with native vLLM.
# repeat_count=1,
expected_response
=
[
"green"
],
# # With proper prompt templating, the model actually only returns "green",
temperature
=
0.0
,
# # verified behavior with native vLLM.
max_tokens
=
100
,
# expected_response=["green"],
)
# temperature=0.0,
],
# max_tokens=100,
),
# )
# ],
# ),
"multimodal_agg_frontend_decoding"
:
VLLMConfig
(
"multimodal_agg_frontend_decoding"
:
VLLMConfig
(
name
=
"multimodal_agg_frontend_decoding"
,
name
=
"multimodal_agg_frontend_decoding"
,
directory
=
vllm_dir
,
directory
=
vllm_dir
,
...
@@ -755,9 +752,8 @@ def test_multimodal_b64(
...
@@ -755,9 +752,8 @@ def test_multimodal_b64(
This test is separate because it loads the required image at runtime
This test is separate because it loads the required image at runtime
(not collection time), ensuring it only fails when actually executed.
(not collection time), ensuring it only fails when actually executed.
"""
"""
# Load B64 image at test execution time
# Load B64 image at test execution time (uses real PNG even if MULTIMODAL_IMG is LFS pointer)
with
open
(
MULTIMODAL_IMG_PATH
,
"rb"
)
as
f
:
b64_img
=
base64
.
b64encode
(
get_multimodal_test_image_bytes
()).
decode
()
b64_img
=
base64
.
b64encode
(
f
.
read
()).
decode
()
# Create payload with B64 image
# Create payload with B64 image
b64_payload
=
chat_payload
(
b64_payload
=
chat_payload
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment