Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f46498f2
Unverified
Commit
f46498f2
authored
Apr 02, 2026
by
Wang, Yi
Committed by
GitHub
Apr 02, 2026
Browse files
feat: add sglang embedding cache (#7674)
Signed-off-by:
Wang, Yi
<
yi.a.wang@intel.com
>
parent
93530057
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
245 additions
and
45 deletions
+245
-45
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
...lang/request_handlers/multimodal/encode_worker_handler.py
+141
-45
components/src/dynamo/sglang/tests/test_sglang_multimodal_embedding_cache.py
...mo/sglang/tests/test_sglang_multimodal_embedding_cache.py
+104
-0
No files found.
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
View file @
f46498f2
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
import
hashlib
import
json
import
json
import
logging
import
logging
from
typing
import
Any
,
AsyncIterator
,
Dict
,
Optional
from
typing
import
Any
,
AsyncIterator
,
Dict
,
Optional
...
@@ -17,6 +18,10 @@ from sglang.srt.parser.conversation import chat_templates
...
@@ -17,6 +18,10 @@ from sglang.srt.parser.conversation import chat_templates
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
dynamo._core
import
Client
,
Context
from
dynamo._core
import
Client
,
Context
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
CachedEmbedding
,
MultimodalEmbeddingCacheManager
,
)
from
dynamo.common.multimodal
import
EMBEDDING_SENDER_FACTORIES
from
dynamo.common.multimodal
import
EMBEDDING_SENDER_FACTORIES
from
dynamo.common.utils
import
nvtx_utils
as
_nvtx
from
dynamo.common.utils
import
nvtx_utils
as
_nvtx
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.args
import
Config
...
@@ -123,9 +128,133 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
...
@@ -123,9 +128,133 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
)
)
self
.
embedding_sender
=
sender
()
self
.
embedding_sender
=
sender
()
# Optional CPU-side LRU embedding cache
self
.
_embedding_cache
:
MultimodalEmbeddingCacheManager
|
None
=
None
capacity_gb
=
config
.
dynamo_args
.
multimodal_embedding_cache_capacity_gb
if
capacity_gb
>
0
:
capacity_bytes
=
int
(
capacity_gb
*
1024
**
3
)
self
.
_embedding_cache
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
)
logger
.
info
(
"Multimodal embedding cache enabled: %.2f GB"
,
capacity_gb
)
def
cleanup
(
self
)
->
None
:
def
cleanup
(
self
)
->
None
:
pass
pass
@staticmethod
def _url_hash(url: str) -> str:
    """Derive the embedding-cache key for an image URL.

    The key is the hexadecimal form of a 32-byte blake2b digest computed
    over the URL's encoded bytes, so equal URLs always map to the same key.
    """
    hasher = hashlib.blake2b(digest_size=32)
    hasher.update(url.encode())
    return hasher.hexdigest()
@staticmethod
def _split_token_counts(grid_list: list, total_tokens: int) -> list[int]:
    """Compute per-image embedding token counts from image_grid_thw shapes.

    Each entry in ``grid_list`` is ``[t, h, w]`` (a list or a tuple). The
    spatial grid size ``h * w`` is treated as proportional to the number of
    embedding tokens for that image. The shared merge factor is inferred
    from the ratio of the summed grid sizes to ``total_tokens`` and then
    applied to every image.

    Args:
        grid_list: One ``[t, h, w]`` entry per image, in embedding order.
        total_tokens: Number of embedding rows to distribute across images.

    Returns:
        The token count of each image, in the same order as ``grid_list``.

    Raises:
        ValueError: If ``total_tokens`` is not positive, an entry is not a
            3-element list/tuple, the summed grid size is not positive, or
            the grid sizes cannot be divided evenly into ``total_tokens``.
    """
    if total_tokens <= 0:
        raise ValueError("Invalid token count for embeddings")

    grid_sizes = []
    for image_grid_thw in grid_list:
        # Accept tuples as well as lists: a non-tensor grid returned by the
        # encoder is forwarded as-is and may hold either sequence type.
        if (
            not isinstance(image_grid_thw, (list, tuple))
            or len(image_grid_thw) != 3
        ):
            raise ValueError(f"Invalid image_grid_thw: {image_grid_thw}")
        # Only the spatial dimensions (h, w) contribute to the grid size.
        grid_sizes.append(int(image_grid_thw[1] * image_grid_thw[2]))

    total_grid_tokens = sum(grid_sizes)
    if total_grid_tokens <= 0:
        raise ValueError("Invalid grid statistics for embeddings")
    if total_grid_tokens % total_tokens != 0:
        raise ValueError(
            "Cannot infer merge factor: grid token total is not divisible "
            "by embedding token total"
        )

    merge_factor = total_grid_tokens // total_tokens
    token_counts = []
    for grid_count in grid_sizes:
        if grid_count % merge_factor != 0:
            raise ValueError(
                "Cannot split embeddings: per-image grid token count not "
                "divisible by inferred merge factor"
            )
        token_counts.append(grid_count // merge_factor)

    # Defensive consistency check before the counts are used to split a tensor.
    if sum(token_counts) != total_tokens:
        raise ValueError(
            "Cannot split embeddings: per-image token counts do not match "
            "embedding token total"
        )
    return token_counts
async def _encode_with_cache(self, image_urls: list[str]) -> tuple[Any, torch.Tensor]:
    """Cache-aware vision encoding.

    Looks up every URL in the embedding cache. URLs that miss are encoded
    together in one ``self.encoder._encode()`` call, split per image, stored
    in the cache, and then merged with the cache hits in the original URL
    order.

    Args:
        image_urls: Image URLs to encode, in request order.

    Returns:
        ``(image_grid_dim, embeddings)``: a tensor built from the per-image
        ``image_grid_thw`` entries and the per-image embeddings concatenated
        along dim 0, both in the order of ``image_urls``.

    Raises:
        ValueError: If the encoder output is not a 2D tensor, if the number
            of grid entries differs from the number of encoded URLs, or if
            the embeddings cannot be split per image.
    """
    # Internal invariant: callers only take this path when the cache exists.
    assert self._embedding_cache is not None
    cache = self._embedding_cache

    # Hash each URL once; the key is needed for both lookup and insertion.
    cache_keys = [self._url_hash(url) for url in image_urls]

    entries: dict[int, CachedEmbedding] = {}
    uncached_indices: list[int] = []
    for i, key in enumerate(cache_keys):
        hit = cache.get(key)
        if hit is not None:
            logger.debug("Embedding cache hit for URL index %d", i)
            entries[i] = hit
        else:
            uncached_indices.append(i)

    if uncached_indices:
        uncached_urls = [image_urls[i] for i in uncached_indices]
        grid_dim, new_embeddings = await self.encoder._encode(uncached_urls)

        # Validate the type before touching tensor attributes so that an
        # unexpected encoder output raises the intended ValueError.
        if not (
            isinstance(new_embeddings, torch.Tensor) and new_embeddings.ndim == 2
        ):
            raise ValueError(
                f"Unsupported embeddings type from encoder: {type(new_embeddings)}"
            )

        # SGLang's _encode outputs are expected to be on CPU already; use CPU
        # as the target device for consistency.
        target_device = torch.device("cpu")
        if new_embeddings.device != target_device:
            logger.warning(
                "SGLang _encode returned embeddings on %s, expected CPU. Moving to CPU.",
                new_embeddings.device,
            )
            new_embeddings = new_embeddings.to(target_device)

        grid_list: list = (
            grid_dim.tolist() if isinstance(grid_dim, torch.Tensor) else grid_dim
        )
        # zip() below would silently drop entries on a length mismatch and
        # surface later as a KeyError during reassembly; fail loudly instead.
        if len(grid_list) != len(uncached_urls):
            raise ValueError(
                f"Encoder returned {len(grid_list)} image_grid_thw entries "
                f"for {len(uncached_urls)} image URLs"
            )

        token_counts = self._split_token_counts(grid_list, new_embeddings.shape[0])
        split_tensors = torch.split(new_embeddings, token_counts, dim=0)
        # torch.split returns views that share the batch's storage, and
        # .contiguous() is a no-op on such row slices. Clone when the batch
        # holds several images so each cached entry owns its own memory and
        # does not keep the whole batch alive.
        needs_copy = len(split_tensors) > 1
        for orig_idx, tensor, grid_thw in zip(
            uncached_indices, split_tensors, grid_list
        ):
            entry = CachedEmbedding(
                tensor=tensor.clone() if needs_copy else tensor.contiguous(),
                image_grid_thw=grid_thw,
            )
            cache.set(cache_keys[orig_idx], entry)
            entries[orig_idx] = entry

    # Reassemble results in original URL order.
    ordered = [entries[i] for i in range(len(image_urls))]
    all_grid_thw = [entry.image_grid_thw for entry in ordered]
    full_embeddings = torch.cat([entry.tensor for entry in ordered], dim=0)
    return torch.tensor(all_grid_thw), full_embeddings
def
_extract_image_urls
(
self
,
request
:
Dict
[
str
,
Any
])
->
list
[
str
]:
def
_extract_image_urls
(
self
,
request
:
Dict
[
str
,
Any
])
->
list
[
str
]:
"""
"""
Extract image URLs from the multi_modal_data field of a PreprocessedRequest.
Extract image URLs from the multi_modal_data field of a PreprocessedRequest.
...
@@ -200,6 +329,12 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
...
@@ -200,6 +329,12 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
try
:
try
:
with
_nvtx
.
annotate
(
"mm:enc:vision_encode"
,
color
=
"red"
):
with
_nvtx
.
annotate
(
"mm:enc:vision_encode"
,
color
=
"red"
):
if
self
.
_embedding_cache
is
not
None
:
(
image_grid_dim
,
precomputed_embeddings
,
)
=
await
self
.
_encode_with_cache
(
image_urls
)
else
:
image_grid_dim
,
precomputed_embeddings
=
await
self
.
encoder
.
_encode
(
image_grid_dim
,
precomputed_embeddings
=
await
self
.
encoder
.
_encode
(
image_urls
image_urls
)
)
...
@@ -213,54 +348,15 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
...
@@ -213,54 +348,15 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
if
len
(
image_grid_thw_list
)
!=
len
(
multimodal_groups
):
if
len
(
image_grid_thw_list
)
!=
len
(
multimodal_groups
):
raise
ValueError
(
"image_grid_thw size mismatch"
)
raise
ValueError
(
"image_grid_thw size mismatch"
)
def
_build_token_counts
(
total_tokens
:
int
)
->
list
[
int
]:
if
total_tokens
<=
0
:
raise
ValueError
(
"Invalid token statistics for embeddings"
)
# image_grid_thw is [t, h, w]. We derive per-item relative sizes
# from spatial grid (h * w), then infer merge factor
# from the total embedding token count.
grid_sizes
=
[]
for
image_grid_thw
in
image_grid_thw_list
:
if
not
isinstance
(
image_grid_thw
,
list
)
or
len
(
image_grid_thw
)
!=
3
:
raise
ValueError
(
"Cannot split embeddings: invalid image_grid_thw"
)
grid_sizes
.
append
(
int
(
image_grid_thw
[
1
]
*
image_grid_thw
[
2
]))
total_grid_tokens
=
sum
(
grid_sizes
)
if
total_grid_tokens
<=
0
:
raise
ValueError
(
"Invalid grid statistics for embeddings"
)
if
total_grid_tokens
%
total_tokens
!=
0
:
raise
ValueError
(
"Cannot infer merge factor: grid token total is not divisible by embedding token total"
)
merge_factor
=
total_grid_tokens
//
total_tokens
token_counts
=
[]
for
grid_count
in
grid_sizes
:
if
grid_count
%
merge_factor
!=
0
:
raise
ValueError
(
"Cannot split embeddings: per-image grid token count not divisible by inferred merge factor"
)
token_counts
.
append
(
grid_count
//
merge_factor
)
if
sum
(
token_counts
)
!=
total_tokens
:
raise
ValueError
(
"Cannot split embeddings: per-image token counts do not match embedding token total"
)
return
token_counts
if
isinstance
(
precomputed_embeddings
,
torch
.
Tensor
):
if
isinstance
(
precomputed_embeddings
,
torch
.
Tensor
):
if
precomputed_embeddings
.
ndim
!=
2
:
if
precomputed_embeddings
.
ndim
!=
2
:
raise
ValueError
(
raise
ValueError
(
"Unsupported embeddings tensor rank from encoder: "
"Unsupported embeddings tensor rank from encoder: "
f
"
{
precomputed_embeddings
.
ndim
}
. Expected 2D [tokens, hidden]."
f
"
{
precomputed_embeddings
.
ndim
}
. Expected 2D [tokens, hidden]."
)
)
token_counts
=
self
.
_split_token_counts
(
token_counts
=
_build_token_counts
(
precomputed_embeddings
.
shape
[
0
])
image_grid_thw_list
,
precomputed_embeddings
.
shape
[
0
]
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Unsupported embeddings type from encoder: "
"Unsupported embeddings type from encoder: "
...
...
components/src/dynamo/sglang/tests/test_sglang_multimodal_embedding_cache.py
0 → 100644
View file @
f46498f2
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for SGLang multimodal embedding cache behavior."""
from
types
import
SimpleNamespace
from
unittest.mock
import
AsyncMock
import
pytest
import
torch
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
CachedEmbedding
,
MultimodalEmbeddingCacheManager
,
)
from
dynamo.sglang.request_handlers.multimodal.encode_worker_handler
import
(
MultimodalEncodeWorkerHandler
,
)
# Markers applied to every test collected from this module.
pytestmark = [
    pytest.mark.unit,
    pytest.mark.sglang,
    pytest.mark.gpu_1,  # sglang tests run on GPU-enabled workers
    pytest.mark.max_vram_gib(0),
    pytest.mark.pre_merge,
]
@pytest.fixture
def cache_handler() -> MultimodalEncodeWorkerHandler:
    """Build a bare handler wired with only what the cache path needs.

    ``__init__`` is bypassed via ``__new__``; the instance gets a 32 MiB
    embedding cache and an encoder stub whose ``_encode`` is an ``AsyncMock``.
    """
    cache_capacity = 32 * 1024 * 1024
    encoder_stub = SimpleNamespace(_encode=AsyncMock())

    instance = MultimodalEncodeWorkerHandler.__new__(MultimodalEncodeWorkerHandler)
    instance._embedding_cache = MultimodalEmbeddingCacheManager(
        capacity_bytes=cache_capacity
    )
    instance.encoder = encoder_stub
    return instance
@pytest.mark.asyncio
async def test_encode_with_cache_partial_hit_and_reuse(
    cache_handler: MultimodalEncodeWorkerHandler,
) -> None:
    """Partial-hit should encode only misses and preserve URL order in output."""
    image_urls = [
        "http://example.com/a.jpg",
        "http://example.com/b.jpg",
        "http://example.com/c.jpg",
    ]
    encode_mock = cache_handler.encoder._encode

    # Seed the cache with the middle URL only (4 tokens x 3 hidden).
    preloaded = torch.full((4, 3), fill_value=-1.0)
    preloaded_entry = CachedEmbedding(tensor=preloaded, image_grid_thw=[1, 2, 2])
    cache_handler._embedding_cache.set(
        cache_handler._url_hash(image_urls[1]), preloaded_entry
    )

    # The encoder only sees the misses (first and last URL): token counts [8, 4].
    fresh = torch.arange(12 * 3, dtype=torch.float32).reshape(12, 3)
    encode_mock.return_value = (torch.tensor([[1, 2, 4], [1, 2, 2]]), fresh)

    first_grid, first_embeddings = await cache_handler._encode_with_cache(image_urls)

    # Exactly one encoder call, covering the uncached URLs only.
    encode_mock.assert_awaited_once_with([image_urls[0], image_urls[2]])

    # Output follows the original URL order: a(8), b(4 cached), c(4).
    assert first_grid.tolist() == [[1, 2, 4], [1, 2, 2], [1, 2, 2]]
    assert torch.equal(first_embeddings[0:8], fresh[0:8])
    assert torch.equal(first_embeddings[8:12], preloaded)
    assert torch.equal(first_embeddings[12:16], fresh[8:12])

    # A repeat request is served entirely from cache: no further encoder calls.
    second_grid, second_embeddings = await cache_handler._encode_with_cache(
        image_urls
    )
    assert encode_mock.await_count == 1
    assert second_grid.tolist() == first_grid.tolist()
    assert torch.equal(second_embeddings, first_embeddings)
@pytest.mark.asyncio
async def test_encode_with_cache_all_hit_no_remote_call(
    cache_handler: MultimodalEncodeWorkerHandler,
) -> None:
    """All-cache-hit path should not call encoder at all."""
    image_urls = ["http://example.com/x.jpg", "http://example.com/y.jpg"]
    emb_x = torch.ones(2, 3)
    emb_y = torch.ones(1, 3) * 9

    # Pre-populate the cache for both URLs, pairing each with its grid shape.
    seeded = (
        (image_urls[0], emb_x, [1, 1, 2]),
        (image_urls[1], emb_y, [1, 1, 1]),
    )
    for url, tensor, grid_thw in seeded:
        cache_handler._embedding_cache.set(
            cache_handler._url_hash(url),
            CachedEmbedding(tensor=tensor, image_grid_thw=grid_thw),
        )

    grid, full_embeddings = await cache_handler._encode_with_cache(image_urls)

    cache_handler.encoder._encode.assert_not_called()
    assert grid.tolist() == [[1, 1, 2], [1, 1, 1]]
    expected = torch.cat([emb_x, emb_y], dim=0)
    assert torch.equal(full_embeddings, expected)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment