Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
a78a4260
"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "2b5655fd43fa45f002a532d0d8239c4b4d99ac71"
Unverified
Commit
a78a4260
authored
Feb 04, 2026
by
Qi Wang
Committed by
GitHub
Feb 04, 2026
Browse files
feat: use encoder cache in TRT-LLM EPD workflow (#5780)
parent
b12e6710
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
507 additions
and
65 deletions
+507
-65
components/src/dynamo/trtllm/multimodal/__init__.py
components/src/dynamo/trtllm/multimodal/__init__.py
+2
-0
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
+255
-0
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+14
-63
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
.../trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
+134
-0
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
...llm/tests/request_handlers/test_trtllm_prefill_handler.py
+102
-2
No files found.
components/src/dynamo/trtllm/multimodal/__init__.py
View file @
a78a4260
...
@@ -2,9 +2,11 @@
...
@@ -2,9 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
.cuda_ipc
import
extract_embeddings_from_handles
from
.cuda_ipc
import
extract_embeddings_from_handles
from
.embedding_fetcher
import
fetch_embeddings_from_encoder
from
.hasher
import
MultimodalHasher
from
.hasher
import
MultimodalHasher
__all__
=
[
__all__
=
[
"MultimodalHasher"
,
"MultimodalHasher"
,
"extract_embeddings_from_handles"
,
"extract_embeddings_from_handles"
,
"fetch_embeddings_from_encoder"
,
]
]
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
0 → 100644
View file @
a78a4260
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Embedding fetcher utilities for multimodal processing with caching.
Provides utility functions for fetching image embeddings from remote encoder
with per-URL caching support.
"""
import
logging
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.common.multimodal.async_encoder_cache
import
EncoderCacheManager
from
dynamo.trtllm.multimodal.cuda_ipc
import
extract_embeddings_from_handles
from
dynamo.trtllm.multimodal.hasher
import
MultimodalHasher
logger
=
logging
.
getLogger
(
__name__
)
async def fetch_embeddings_from_encoder(
    image_urls: List[str],
    request: Dict[str, Any],
    encode_client: Any,
    encoder_cache: Optional[EncoderCacheManager] = None,
) -> Union[List[torch.Tensor], DisaggregatedParams]:
    """
    Fetch embeddings from remote encode worker.

    Dispatches to one of two paths depending on whether a cache was supplied.

    Args:
        image_urls: List of image URLs to encode (must not be empty)
        request: Request dict. On the no-cache path it is forwarded to the
            encode worker and may be updated in place with the
            ``_epd_processed_prompt`` / ``_epd_prompt_token_ids`` keys. On the
            cache path it is only used as a template for a filtered copy.
        encode_client: Client to call remote encode worker
        encoder_cache: Optional cache for embeddings

    Returns:
        - List[torch.Tensor]: When a cache is supplied (one tensor per URL,
          in the order of ``image_urls``)
        - DisaggregatedParams: When no cache is supplied (built from the
          encode worker's ``ep_disaggregated_params``)

    Raises:
        ValueError: If image_urls is empty
    """
    if not image_urls:
        raise ValueError("image_urls must not be empty")

    logger.info(f"fetch_embeddings_from_encoder: image_urls={image_urls}")

    # Compare against None instead of relying on truthiness: a cache object
    # that evaluates as falsy while empty (e.g. one defining __len__) would
    # otherwise never take the cache path and therefore never be populated.
    if encoder_cache is not None:
        # Cache path: the encode call must not write EPD metadata into the
        # request, because it runs against a filtered copy of the request.
        return await _fetch_embeddings_with_cache(
            image_urls,
            request,
            encoder_cache,
            lambda req: _remote_encode_full_epd(
                req, encode_client, update_request_for_decode=False
            ),
        )

    # No cache: return DisaggregatedParams directly (no GPU→CPU extraction)
    return await _remote_encode_full_epd(
        request, encode_client, update_request_for_decode=True
    )
async
def
_remote_encode_full_epd
(
request
:
Dict
[
str
,
Any
],
encode_client
:
Any
,
update_request_for_decode
:
bool
=
True
,
)
->
DisaggregatedParams
:
"""
Call encode worker for full EPD flow.
Args:
request: Request dict
encode_client: Client to call remote encode worker
update_request_for_decode: If True, store EPD metadata in request
Returns:
DisaggregatedParams with multimodal_embedding_handles
Raises:
RuntimeError: If encode worker returns invalid response
"""
encode_response
=
None
async
for
res
in
await
encode_client
.
round_robin
(
request
):
encode_response
=
res
.
data
()
break
if
not
encode_response
:
raise
RuntimeError
(
"Did not receive a response from the encode worker."
)
if
"ep_disaggregated_params"
not
in
encode_response
:
raise
RuntimeError
(
"Encode response missing ep_disaggregated_params."
)
params_dict
=
encode_response
[
"ep_disaggregated_params"
]
if
params_dict
is
None
:
raise
RuntimeError
(
"ep_disaggregated_params is None."
)
# Store EPD metadata in request for decode worker (only when not using cache)
if
update_request_for_decode
:
if
"processed_prompt"
in
encode_response
:
request
[
"_epd_processed_prompt"
]
=
encode_response
[
"processed_prompt"
]
if
"prompt_token_ids"
in
encode_response
:
request
[
"_epd_prompt_token_ids"
]
=
encode_response
[
"prompt_token_ids"
]
return
DisaggregatedParams
(
**
params_dict
)
async
def
_fetch_embeddings_with_cache
(
image_urls
:
List
[
str
],
request
:
Dict
[
str
,
Any
],
cache
:
EncoderCacheManager
,
encode_fn
:
Callable
[[
Dict
[
str
,
Any
]],
DisaggregatedParams
],
)
->
List
[
torch
.
Tensor
]:
"""
Encode image URLs with per-URL caching and partial cache usage.
Checks cache for each URL. Cached embeddings are reused directly.
For uncached URLs, sends a single encode request for only those URLs,
then caches the results.
Args:
image_urls: List of image URLs to encode
request: Original request dict containing the images
cache: AsyncEncoderCache instance for caching embeddings
encode_fn: Async function that encodes a request and returns ep_disaggregated_params
Should accept a modified request dict with subset of URLs
Returns:
List of embedding tensors for all images in original order
"""
if
not
image_urls
:
raise
ValueError
(
"image_urls list is empty"
)
# Check cache for each URL
embeddings_with_index
=
[]
# List of (original_index, tensor)
uncached_urls
=
[]
uncached_indices
=
[]
uncached_hashes
=
[]
for
i
,
url
in
enumerate
(
image_urls
):
url_hash
=
MultimodalHasher
.
hash_bytes
(
url
.
encode
())
cached
=
cache
.
get
(
url_hash
)
if
cached
is
not
None
:
logger
.
info
(
f
"fetch_embeddings_with_cache: cache hit for URL:
{
url
}
"
)
embeddings_with_index
.
append
((
i
,
cached
))
else
:
logger
.
info
(
f
"fetch_embeddings_with_cache: cache miss for URL:
{
url
}
"
)
uncached_urls
.
append
(
url
)
uncached_indices
.
append
(
i
)
uncached_hashes
.
append
(
url_hash
)
# If all cached, return immediately
if
not
uncached_urls
:
logger
.
info
(
f
"fetch_embeddings_with_cache: all
{
len
(
image_urls
)
}
URLs cached"
)
embeddings_with_index
.
sort
(
key
=
lambda
x
:
x
[
0
])
tensors
=
[
t
for
_
,
t
in
embeddings_with_index
]
return
tensors
# Encode uncached URLs
logger
.
info
(
f
"fetch_embeddings_with_cache: encoding
{
len
(
uncached_urls
)
}
uncached URLs"
)
# Create modified request with only uncached URLs
modified_request
=
_create_request_with_urls
(
request
,
uncached_urls
)
# Call encode function
ep_disaggregated_params
=
await
encode_fn
(
modified_request
)
if
not
ep_disaggregated_params
:
raise
RuntimeError
(
"fetch_embeddings_with_cache: Failed to get ep_disaggregated_params"
)
# Extract handles from disaggregated params
handles
=
getattr
(
ep_disaggregated_params
,
"multimodal_embedding_handles"
,
None
)
if
not
handles
:
raise
RuntimeError
(
"fetch_embeddings_with_cache: No multimodal_embedding_handles in ep_disaggregated_params"
)
# Extract tensors from CUDA IPC handles
new_tensors
=
await
extract_embeddings_from_handles
(
handles
)
# Cache new tensors (reuse hashes computed during cache lookup)
for
url
,
url_hash
,
tensor
in
zip
(
uncached_urls
,
uncached_hashes
,
new_tensors
):
cache
.
set
(
url_hash
,
tensor
)
logger
.
info
(
f
"fetch_embeddings_with_cache: cached embedding for URL:
{
url
}
, shape:
{
tensor
.
shape
}
"
)
# Add new tensors to our list with their original indices
for
idx
,
tensor
in
zip
(
uncached_indices
,
new_tensors
):
embeddings_with_index
.
append
((
idx
,
tensor
))
# Sort by original order and return list
embeddings_with_index
.
sort
(
key
=
lambda
x
:
x
[
0
])
tensors
=
[
t
for
_
,
t
in
embeddings_with_index
]
return
tensors
def
_create_request_with_urls
(
original_request
:
Dict
[
str
,
Any
],
image_urls
:
List
[
str
]
)
->
Dict
[
str
,
Any
]:
"""
Create a modified request containing only specified image URLs.
Args:
original_request: Original request dict
image_urls: URLs to include in the modified request
Returns:
Modified request dict with filtered image URLs
"""
# Deep copy to avoid modifying original
import
copy
modified_request
=
copy
.
deepcopy
(
original_request
)
# Extract messages
messages
=
modified_request
.
get
(
"extra_args"
,
{}).
get
(
"messages"
,
modified_request
.
get
(
"messages"
,
[])
)
# Filter messages to only include specified URLs
filtered_messages
=
[]
for
message
in
messages
:
new_message
=
{
"role"
:
message
.
get
(
"role"
,
"user"
),
"content"
:
[]}
for
content
in
message
.
get
(
"content"
,
[]):
if
isinstance
(
content
,
dict
):
if
content
.
get
(
"type"
)
==
"image_url"
:
# Only include if URL is in our list
url
=
content
.
get
(
"image_url"
,
{}).
get
(
"url"
)
if
url
in
image_urls
:
new_message
[
"content"
].
append
(
content
)
elif
content
.
get
(
"type"
)
==
"text"
:
# Keep text content
new_message
[
"content"
].
append
(
content
)
elif
isinstance
(
content
,
str
):
new_message
[
"content"
].
append
(
content
)
if
new_message
[
"content"
]:
filtered_messages
.
append
(
new_message
)
# Update the request with filtered messages
if
"extra_args"
in
modified_request
:
modified_request
[
"extra_args"
][
"messages"
]
=
filtered_messages
else
:
modified_request
[
"messages"
]
=
filtered_messages
return
modified_request
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
a78a4260
...
@@ -4,12 +4,11 @@
...
@@ -4,12 +4,11 @@
import
logging
import
logging
from
typing
import
Optional
from
typing
import
Optional
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo._core
import
Context
from
dynamo._core
import
Context
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.trtllm.encode_helper
import
EncodeHelper
from
dynamo.trtllm.encode_helper
import
EncodeHelper
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.request_handlers.handler_base
import
(
from
dynamo.trtllm.request_handlers.handler_base
import
(
HandlerBase
,
HandlerBase
,
RequestHandlerConfig
,
RequestHandlerConfig
,
...
@@ -109,29 +108,6 @@ class PrefillHandler(HandlerBase):
...
@@ -109,29 +108,6 @@ class PrefillHandler(HandlerBase):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
_encoder_cache
=
encoder_cache
self
.
_encoder_cache
=
encoder_cache
async
def
remote_encode_full_epd
(
self
,
request
:
dict
):
"""
Call encode worker for full EPD flow and unpack the response.
Args:
request: Request dict
Returns:
Encoder's DisaggregatedParams to be used by the prefill worker
"""
encode_response
=
None
async
for
res
in
await
self
.
encode_client
.
round_robin
(
request
):
encode_response
=
res
.
data
()
break
if
not
encode_response
:
raise
RuntimeError
(
"Did not receive a response from the encode worker."
)
ep_disaggregated_params
=
self
.
_unpack_full_epd_response
(
encode_response
,
request
)
return
ep_disaggregated_params
async
def
remote_encode_with_nixl
(
self
,
request
:
dict
):
async
def
remote_encode_with_nixl
(
self
,
request
:
dict
):
"""
"""
Call encode worker for NIXL flow to load embeddings and unpack the response.
Call encode worker for NIXL flow to load embeddings and unpack the response.
...
@@ -156,43 +132,6 @@ class PrefillHandler(HandlerBase):
...
@@ -156,43 +132,6 @@ class PrefillHandler(HandlerBase):
encode_response
,
self
.
connector
encode_response
,
self
.
connector
)
)
def
_unpack_full_epd_response
(
self
,
encode_response
:
dict
,
request
:
dict
)
->
Optional
[
DisaggregatedParams
]:
"""
Unpack encode worker response from full EPD flow.
Extracts DisaggregatedParams and stores EPD metadata in the request
for downstream processing (multimodal_processor, decode worker).
Args:
encode_response: Response dict from encode worker
request: Request dict to store metadata in (modified in-place)
Returns:
DisaggregatedParams if present in response, None otherwise
"""
if
"ep_disaggregated_params"
not
in
encode_response
:
return
None
params_dict
=
encode_response
[
"ep_disaggregated_params"
]
if
params_dict
is
None
:
return
None
# Reconstruct DisaggregatedParams object from dict
ep_disaggregated_params
=
DisaggregatedParams
(
**
params_dict
)
ep_disaggregated_params
.
request_type
=
"context_only"
# Store processed prompt from encoder (includes <image> tokens)
if
"processed_prompt"
in
encode_response
:
request
[
"_epd_processed_prompt"
]
=
encode_response
[
"processed_prompt"
]
# Store prompt_token_ids from encoder for decode worker
if
"prompt_token_ids"
in
encode_response
:
request
[
"_epd_prompt_token_ids"
]
=
encode_response
[
"prompt_token_ids"
]
return
ep_disaggregated_params
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
"""
"""
Prefill worker: process prompt and return disaggregated_params.
Prefill worker: process prompt and return disaggregated_params.
...
@@ -230,7 +169,19 @@ class PrefillHandler(HandlerBase):
...
@@ -230,7 +169,19 @@ class PrefillHandler(HandlerBase):
# Handle image URLs (full E-PD flow with MultimodalEncoder)
# Handle image URLs (full E-PD flow with MultimodalEncoder)
elif
image_urls
:
elif
image_urls
:
if
self
.
encode_client
:
if
self
.
encode_client
:
ep_disaggregated_params
=
await
self
.
remote_encode_full_epd
(
request
)
logging
.
info
(
f
"PrefillHandler: image_urls=
{
image_urls
}
"
)
result
=
await
fetch_embeddings_from_encoder
(
image_urls
,
request
,
self
.
encode_client
,
self
.
_encoder_cache
,
)
if
isinstance
(
result
,
list
):
# Cache path: got List[torch.Tensor]
embeddings_tensor
=
result
else
:
# No-cache path: got DisaggregatedParams
ep_disaggregated_params
=
result
# Normal flow: Generate the prefill response locally with embeddings
# Normal flow: Generate the prefill response locally with embeddings
response_count
=
0
response_count
=
0
...
...
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
0 → 100644
View file @
a78a4260
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for fetch_embeddings_from_encoder."""
from
typing
import
Any
from
unittest.mock
import
AsyncMock
,
patch
import
pytest
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.multimodal.hasher
import
MultimodalHasher
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
gpu_0
,
]
def create_mock_encode_client(
    embeddings: list[torch.Tensor],
    processed_prompt: str = "prompt",
    prompt_token_ids: list[int] | None = None,
) -> AsyncMock:
    """Build a fake encode client whose ``round_robin`` streams one response.

    The single response's ``data()`` returns a dict with one placeholder
    handle string (``"h0"``, ``"h1"``, ...) per entry of ``embeddings``, the
    given ``processed_prompt``, and ``prompt_token_ids`` (``[1, 2, 3]`` when
    the argument is None or empty).
    """

    def _build_payload() -> dict[str, Any]:
        # Built fresh on every call so callers never share a mutable dict.
        handle_names = [f"h{i}" for i, _ in enumerate(embeddings)]
        return {
            "ep_disaggregated_params": {"multimodal_embedding_handles": handle_names},
            "processed_prompt": processed_prompt,
            "prompt_token_ids": prompt_token_ids or [1, 2, 3],
        }

    class MockResponse:
        def data(self):
            return _build_payload()

    async def _single_response_stream():
        yield MockResponse()

    async def mock_round_robin(req: dict[str, Any]) -> Any:
        # Awaiting round_robin yields an async iterator, like the real client.
        return _single_response_stream()

    mock_client = AsyncMock()
    mock_client.round_robin = mock_round_robin
    return mock_client
@pytest.fixture
def encoder_cache() -> EncoderCacheManager:
    """Provide an EncoderCacheManager constructed with 10 * 1024 * 1024 capacity bytes."""
    capacity = 10 * 1024 * 1024
    return EncoderCacheManager(capacity_bytes=capacity)
class TestFetchEmbeddingsFromEncoder:
    """Tests for fetch_embeddings_from_encoder function."""

    @pytest.mark.asyncio
    async def test_partial_cache_no_metadata_update(self, encoder_cache):
        """Cache path: results keep input order and request is NOT updated with EPD metadata."""
        url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
        embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
        # Only url1 is pre-cached, so url2 must go through the encode client.
        encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1)

        request: dict[str, Any] = {"messages": []}
        mock_client = create_mock_encode_client([embedding2])

        # Patch handle extraction: the mock handles are plain strings, not
        # real CUDA IPC handles.
        with patch(
            "dynamo.trtllm.multimodal.embedding_fetcher.extract_embeddings_from_handles",
            AsyncMock(return_value=[embedding2]),
        ):
            result = await fetch_embeddings_from_encoder(
                [url1, url2], request, mock_client, encoder_cache
            )

        assert len(result) == 2
        # Cached and freshly encoded embeddings come back in input order.
        assert torch.equal(result[0], embedding1)
        assert torch.equal(result[1], embedding2)
        # The miss was written back to the cache.
        assert encoder_cache.get(MultimodalHasher.hash_bytes(url2.encode())) is not None
        assert "_epd_processed_prompt" not in request
        assert "_epd_prompt_token_ids" not in request

    @pytest.mark.asyncio
    async def test_all_cached_no_request_sent(self, encoder_cache):
        """All cached: no encode request sent."""
        url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
        embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
        encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1)
        encoder_cache.set(MultimodalHasher.hash_bytes(url2.encode()), embedding2)

        async def should_not_call(req: dict[str, Any]) -> None:
            raise AssertionError("Should not be called")

        mock_client = AsyncMock()
        mock_client.round_robin = should_not_call

        result = await fetch_embeddings_from_encoder(
            [url1, url2], {"messages": []}, mock_client, encoder_cache
        )

        assert len(result) == 2
        assert torch.equal(result[0], embedding1)
        assert torch.equal(result[1], embedding2)

    @pytest.mark.asyncio
    async def test_no_cache_returns_disaggregated_params(self):
        """No cache: returns DisaggregatedParams directly, request updated with metadata."""
        request: dict[str, Any] = {"messages": []}
        # Pass one embedding so mock generates one handle (DisaggregatedParams requires non-empty handles)
        mock_client = create_mock_encode_client(
            [torch.ones(10, 256)],
            processed_prompt="test <image>",
            prompt_token_ids=[10, 20],
        )

        result = await fetch_embeddings_from_encoder(
            ["http://example.com/img.jpg"], request, mock_client, encoder_cache=None
        )

        assert isinstance(result, DisaggregatedParams)
        # The handles from the encode response are carried through unchanged.
        assert result.multimodal_embedding_handles == ["h0"]
        assert request["_epd_processed_prompt"] == "test <image>"
        assert request["_epd_prompt_token_ids"] == [10, 20]

    @pytest.mark.asyncio
    async def test_empty_urls_raises_error(self, encoder_cache):
        """Empty image_urls raises ValueError."""
        mock_client = AsyncMock()
        with pytest.raises(ValueError, match="image_urls must not be empty"):
            await fetch_embeddings_from_encoder([], {}, mock_client, encoder_cache)
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
View file @
a78a4260
...
@@ -3,15 +3,19 @@
...
@@ -3,15 +3,19 @@
"""Unit tests for PrefillHandler."""
"""Unit tests for PrefillHandler."""
from
unittest.mock
import
MagicMock
from
typing
import
Any
from
unittest.mock
import
AsyncMock
,
MagicMock
,
patch
import
pytest
import
pytest
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.trtllm.request_handlers.handlers
import
PrefillHandler
from
dynamo.trtllm.request_handlers.handlers
import
PrefillHandler
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
pytestmark
=
[
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
gpu_0
,
]
]
...
@@ -29,10 +33,46 @@ def mock_encoder_cache():
...
@@ -29,10 +33,46 @@ def mock_encoder_cache():
cache
=
MagicMock
()
cache
=
MagicMock
()
cache
.
get
=
MagicMock
(
return_value
=
None
)
cache
.
get
=
MagicMock
(
return_value
=
None
)
cache
.
set
=
MagicMock
(
return_value
=
True
)
cache
.
set
=
MagicMock
(
return_value
=
True
)
cache
.
stats
=
{
"hits"
:
0
,
"misses"
:
0
,
"entries"
:
0
}
return
cache
return
cache
@pytest.fixture
def mock_context():
    """Provide a stand-in Context: id() -> "test-id", is_stopped()/is_killed() -> False."""
    return MagicMock(
        id=MagicMock(return_value="test-id"),
        is_stopped=MagicMock(return_value=False),
        is_killed=MagicMock(return_value=False),
    )
@pytest.fixture
def image_request() -> dict[str, Any]:
    """Provide a chat request whose single user message holds one image_url part."""
    image_part = {
        "type": "image_url",
        "image_url": {"url": "http://example.com/image.jpg"},
    }
    user_message = {"role": "user", "content": [image_part]}
    return {"messages": [user_message]}
def setup_multimodal_config(mock_config):
    """Attach a multimodal processor and an encode client to ``mock_config``.

    The processor's ``extract_prompt_and_media`` always returns
    ``("text", ["http://example.com/image.jpg"], [])``.
    """
    processor = MagicMock()
    processor.extract_prompt_and_media = MagicMock(
        return_value=("text", ["http://example.com/image.jpg"], [])
    )
    mock_config.multimodal_processor = processor
    mock_config.encode_client = MagicMock()
class
TestPrefillHandlerInit
:
class
TestPrefillHandlerInit
:
"""Tests for PrefillHandler initialization."""
"""Tests for PrefillHandler initialization."""
...
@@ -42,3 +82,63 @@ class TestPrefillHandlerInit:
...
@@ -42,3 +82,63 @@ class TestPrefillHandlerInit:
assert
handler
.
engine
==
mock_config
.
engine
assert
handler
.
engine
==
mock_config
.
engine
assert
handler
.
_encoder_cache
==
mock_encoder_cache
assert
handler
.
_encoder_cache
==
mock_encoder_cache
class TestPrefillHandlerGenerate:
    """Tests for PrefillHandler.generate method."""

    @pytest.mark.asyncio
    async def test_embeddings_passed_to_generate_locally(
        self, mock_config, mock_encoder_cache, mock_context, image_request
    ):
        """A tensor list returned by the fetcher reaches generate_locally as ``embeddings``."""
        setup_multimodal_config(mock_config)
        handler = PrefillHandler(mock_config, encoder_cache=mock_encoder_cache)
        expected_embeddings = [torch.randn(10, 256)]
        seen: dict[str, Any] = {}

        async def mock_generate_locally(request, context, embeddings, ep_params):
            # Record what the handler forwarded, then emit one dummy response.
            seen["embeddings"] = embeddings
            yield {"result": "mock"}

        with patch(
            "dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder",
            new_callable=AsyncMock,
            return_value=expected_embeddings,
        ) as mock_fetch, patch.object(
            handler, "generate_locally", mock_generate_locally
        ):
            async for _ in handler.generate(image_request, mock_context):
                pass

        mock_fetch.assert_called_once()
        assert seen.get("embeddings") is expected_embeddings

    @pytest.mark.asyncio
    async def test_disaggregated_params_passed_to_generate_locally(
        self, mock_config, mock_context, image_request
    ):
        """DisaggregatedParams returned by the fetcher reaches generate_locally as ``ep_params``."""
        setup_multimodal_config(mock_config)
        handler = PrefillHandler(mock_config, encoder_cache=None)
        expected_params = DisaggregatedParams(request_type="context_only")
        seen: dict[str, Any] = {}

        async def mock_generate_locally(request, context, embeddings, ep_params):
            # Record what the handler forwarded, then emit one dummy response.
            seen["ep_params"] = ep_params
            yield {"result": "mock"}

        with patch(
            "dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder",
            new_callable=AsyncMock,
            return_value=expected_params,
        ) as mock_fetch, patch.object(
            handler, "generate_locally", mock_generate_locally
        ):
            async for _ in handler.generate(image_request, mock_context):
                pass

        mock_fetch.assert_called_once()
        assert seen.get("ep_params") is expected_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment