Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
a78a4260
Unverified
Commit
a78a4260
authored
Feb 04, 2026
by
Qi Wang
Committed by
GitHub
Feb 04, 2026
Browse files
feat: use encoder cache in TRT-LLM EPD workflow (#5780)
parent
b12e6710
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
507 additions
and
65 deletions
+507
-65
components/src/dynamo/trtllm/multimodal/__init__.py
components/src/dynamo/trtllm/multimodal/__init__.py
+2
-0
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
+255
-0
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+14
-63
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
.../trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
+134
-0
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
...llm/tests/request_handlers/test_trtllm_prefill_handler.py
+102
-2
No files found.
components/src/dynamo/trtllm/multimodal/__init__.py
View file @
a78a4260
...
...
@@ -2,9 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
from
.cuda_ipc
import
extract_embeddings_from_handles
from
.embedding_fetcher
import
fetch_embeddings_from_encoder
from
.hasher
import
MultimodalHasher
__all__
=
[
"MultimodalHasher"
,
"extract_embeddings_from_handles"
,
"fetch_embeddings_from_encoder"
,
]
components/src/dynamo/trtllm/multimodal/embedding_fetcher.py
0 → 100644
View file @
a78a4260
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Embedding fetcher utilities for multimodal processing with caching.
Provides utility functions for fetching image embeddings from remote encoder
with per-URL caching support.
"""
import
logging
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.common.multimodal.async_encoder_cache
import
EncoderCacheManager
from
dynamo.trtllm.multimodal.cuda_ipc
import
extract_embeddings_from_handles
from
dynamo.trtllm.multimodal.hasher
import
MultimodalHasher
logger
=
logging
.
getLogger
(
__name__
)
async
def
fetch_embeddings_from_encoder
(
image_urls
:
List
[
str
],
request
:
Dict
[
str
,
Any
],
encode_client
:
Any
,
encoder_cache
:
Optional
[
EncoderCacheManager
]
=
None
,
)
->
Union
[
List
[
torch
.
Tensor
],
DisaggregatedParams
]:
"""
Fetch embeddings from remote encode worker.
Args:
image_urls: List of image URLs to encode (must not be empty)
request: Request dict (used for creating modified requests for caching)
encode_client: Client to call remote encode worker
encoder_cache: Optional cache for embeddings
Returns:
- List[torch.Tensor]: When using cache (CPU tensors from cache)
- DisaggregatedParams: When not using cache (contains CUDA IPC handles)
Raises:
ValueError: If image_urls is empty
"""
if
not
image_urls
:
raise
ValueError
(
"image_urls must not be empty"
)
logger
.
info
(
f
"fetch_embeddings_from_encoder: image_urls=
{
image_urls
}
"
)
if
encoder_cache
:
# Cache path: extract embeddings to CPU tensors
return
await
_fetch_embeddings_with_cache
(
image_urls
,
request
,
encoder_cache
,
lambda
req
:
_remote_encode_full_epd
(
req
,
encode_client
,
update_request_for_decode
=
False
),
)
else
:
# No cache: return DisaggregatedParams directly (no GPU→CPU extraction)
return
await
_remote_encode_full_epd
(
request
,
encode_client
,
update_request_for_decode
=
True
)
async
def
_remote_encode_full_epd
(
request
:
Dict
[
str
,
Any
],
encode_client
:
Any
,
update_request_for_decode
:
bool
=
True
,
)
->
DisaggregatedParams
:
"""
Call encode worker for full EPD flow.
Args:
request: Request dict
encode_client: Client to call remote encode worker
update_request_for_decode: If True, store EPD metadata in request
Returns:
DisaggregatedParams with multimodal_embedding_handles
Raises:
RuntimeError: If encode worker returns invalid response
"""
encode_response
=
None
async
for
res
in
await
encode_client
.
round_robin
(
request
):
encode_response
=
res
.
data
()
break
if
not
encode_response
:
raise
RuntimeError
(
"Did not receive a response from the encode worker."
)
if
"ep_disaggregated_params"
not
in
encode_response
:
raise
RuntimeError
(
"Encode response missing ep_disaggregated_params."
)
params_dict
=
encode_response
[
"ep_disaggregated_params"
]
if
params_dict
is
None
:
raise
RuntimeError
(
"ep_disaggregated_params is None."
)
# Store EPD metadata in request for decode worker (only when not using cache)
if
update_request_for_decode
:
if
"processed_prompt"
in
encode_response
:
request
[
"_epd_processed_prompt"
]
=
encode_response
[
"processed_prompt"
]
if
"prompt_token_ids"
in
encode_response
:
request
[
"_epd_prompt_token_ids"
]
=
encode_response
[
"prompt_token_ids"
]
return
DisaggregatedParams
(
**
params_dict
)
async
def
_fetch_embeddings_with_cache
(
image_urls
:
List
[
str
],
request
:
Dict
[
str
,
Any
],
cache
:
EncoderCacheManager
,
encode_fn
:
Callable
[[
Dict
[
str
,
Any
]],
DisaggregatedParams
],
)
->
List
[
torch
.
Tensor
]:
"""
Encode image URLs with per-URL caching and partial cache usage.
Checks cache for each URL. Cached embeddings are reused directly.
For uncached URLs, sends a single encode request for only those URLs,
then caches the results.
Args:
image_urls: List of image URLs to encode
request: Original request dict containing the images
cache: AsyncEncoderCache instance for caching embeddings
encode_fn: Async function that encodes a request and returns ep_disaggregated_params
Should accept a modified request dict with subset of URLs
Returns:
List of embedding tensors for all images in original order
"""
if
not
image_urls
:
raise
ValueError
(
"image_urls list is empty"
)
# Check cache for each URL
embeddings_with_index
=
[]
# List of (original_index, tensor)
uncached_urls
=
[]
uncached_indices
=
[]
uncached_hashes
=
[]
for
i
,
url
in
enumerate
(
image_urls
):
url_hash
=
MultimodalHasher
.
hash_bytes
(
url
.
encode
())
cached
=
cache
.
get
(
url_hash
)
if
cached
is
not
None
:
logger
.
info
(
f
"fetch_embeddings_with_cache: cache hit for URL:
{
url
}
"
)
embeddings_with_index
.
append
((
i
,
cached
))
else
:
logger
.
info
(
f
"fetch_embeddings_with_cache: cache miss for URL:
{
url
}
"
)
uncached_urls
.
append
(
url
)
uncached_indices
.
append
(
i
)
uncached_hashes
.
append
(
url_hash
)
# If all cached, return immediately
if
not
uncached_urls
:
logger
.
info
(
f
"fetch_embeddings_with_cache: all
{
len
(
image_urls
)
}
URLs cached"
)
embeddings_with_index
.
sort
(
key
=
lambda
x
:
x
[
0
])
tensors
=
[
t
for
_
,
t
in
embeddings_with_index
]
return
tensors
# Encode uncached URLs
logger
.
info
(
f
"fetch_embeddings_with_cache: encoding
{
len
(
uncached_urls
)
}
uncached URLs"
)
# Create modified request with only uncached URLs
modified_request
=
_create_request_with_urls
(
request
,
uncached_urls
)
# Call encode function
ep_disaggregated_params
=
await
encode_fn
(
modified_request
)
if
not
ep_disaggregated_params
:
raise
RuntimeError
(
"fetch_embeddings_with_cache: Failed to get ep_disaggregated_params"
)
# Extract handles from disaggregated params
handles
=
getattr
(
ep_disaggregated_params
,
"multimodal_embedding_handles"
,
None
)
if
not
handles
:
raise
RuntimeError
(
"fetch_embeddings_with_cache: No multimodal_embedding_handles in ep_disaggregated_params"
)
# Extract tensors from CUDA IPC handles
new_tensors
=
await
extract_embeddings_from_handles
(
handles
)
# Cache new tensors (reuse hashes computed during cache lookup)
for
url
,
url_hash
,
tensor
in
zip
(
uncached_urls
,
uncached_hashes
,
new_tensors
):
cache
.
set
(
url_hash
,
tensor
)
logger
.
info
(
f
"fetch_embeddings_with_cache: cached embedding for URL:
{
url
}
, shape:
{
tensor
.
shape
}
"
)
# Add new tensors to our list with their original indices
for
idx
,
tensor
in
zip
(
uncached_indices
,
new_tensors
):
embeddings_with_index
.
append
((
idx
,
tensor
))
# Sort by original order and return list
embeddings_with_index
.
sort
(
key
=
lambda
x
:
x
[
0
])
tensors
=
[
t
for
_
,
t
in
embeddings_with_index
]
return
tensors
def
_create_request_with_urls
(
original_request
:
Dict
[
str
,
Any
],
image_urls
:
List
[
str
]
)
->
Dict
[
str
,
Any
]:
"""
Create a modified request containing only specified image URLs.
Args:
original_request: Original request dict
image_urls: URLs to include in the modified request
Returns:
Modified request dict with filtered image URLs
"""
# Deep copy to avoid modifying original
import
copy
modified_request
=
copy
.
deepcopy
(
original_request
)
# Extract messages
messages
=
modified_request
.
get
(
"extra_args"
,
{}).
get
(
"messages"
,
modified_request
.
get
(
"messages"
,
[])
)
# Filter messages to only include specified URLs
filtered_messages
=
[]
for
message
in
messages
:
new_message
=
{
"role"
:
message
.
get
(
"role"
,
"user"
),
"content"
:
[]}
for
content
in
message
.
get
(
"content"
,
[]):
if
isinstance
(
content
,
dict
):
if
content
.
get
(
"type"
)
==
"image_url"
:
# Only include if URL is in our list
url
=
content
.
get
(
"image_url"
,
{}).
get
(
"url"
)
if
url
in
image_urls
:
new_message
[
"content"
].
append
(
content
)
elif
content
.
get
(
"type"
)
==
"text"
:
# Keep text content
new_message
[
"content"
].
append
(
content
)
elif
isinstance
(
content
,
str
):
new_message
[
"content"
].
append
(
content
)
if
new_message
[
"content"
]:
filtered_messages
.
append
(
new_message
)
# Update the request with filtered messages
if
"extra_args"
in
modified_request
:
modified_request
[
"extra_args"
][
"messages"
]
=
filtered_messages
else
:
modified_request
[
"messages"
]
=
filtered_messages
return
modified_request
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
a78a4260
...
...
@@ -4,12 +4,11 @@
import
logging
from
typing
import
Optional
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo._core
import
Context
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.trtllm.encode_helper
import
EncodeHelper
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.request_handlers.handler_base
import
(
HandlerBase
,
RequestHandlerConfig
,
...
...
@@ -109,29 +108,6 @@ class PrefillHandler(HandlerBase):
super
().
__init__
(
config
)
self
.
_encoder_cache
=
encoder_cache
async
def
remote_encode_full_epd
(
self
,
request
:
dict
):
"""
Call encode worker for full EPD flow and unpack the response.
Args:
request: Request dict
Returns:
Encoder's DisaggregatedParams to be used by the prefill worker
"""
encode_response
=
None
async
for
res
in
await
self
.
encode_client
.
round_robin
(
request
):
encode_response
=
res
.
data
()
break
if
not
encode_response
:
raise
RuntimeError
(
"Did not receive a response from the encode worker."
)
ep_disaggregated_params
=
self
.
_unpack_full_epd_response
(
encode_response
,
request
)
return
ep_disaggregated_params
async
def
remote_encode_with_nixl
(
self
,
request
:
dict
):
"""
Call encode worker for NIXL flow to load embeddings and unpack the response.
...
...
@@ -156,43 +132,6 @@ class PrefillHandler(HandlerBase):
encode_response
,
self
.
connector
)
def
_unpack_full_epd_response
(
self
,
encode_response
:
dict
,
request
:
dict
)
->
Optional
[
DisaggregatedParams
]:
"""
Unpack encode worker response from full EPD flow.
Extracts DisaggregatedParams and stores EPD metadata in the request
for downstream processing (multimodal_processor, decode worker).
Args:
encode_response: Response dict from encode worker
request: Request dict to store metadata in (modified in-place)
Returns:
DisaggregatedParams if present in response, None otherwise
"""
if
"ep_disaggregated_params"
not
in
encode_response
:
return
None
params_dict
=
encode_response
[
"ep_disaggregated_params"
]
if
params_dict
is
None
:
return
None
# Reconstruct DisaggregatedParams object from dict
ep_disaggregated_params
=
DisaggregatedParams
(
**
params_dict
)
ep_disaggregated_params
.
request_type
=
"context_only"
# Store processed prompt from encoder (includes <image> tokens)
if
"processed_prompt"
in
encode_response
:
request
[
"_epd_processed_prompt"
]
=
encode_response
[
"processed_prompt"
]
# Store prompt_token_ids from encoder for decode worker
if
"prompt_token_ids"
in
encode_response
:
request
[
"_epd_prompt_token_ids"
]
=
encode_response
[
"prompt_token_ids"
]
return
ep_disaggregated_params
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
"""
Prefill worker: process prompt and return disaggregated_params.
...
...
@@ -230,7 +169,19 @@ class PrefillHandler(HandlerBase):
# Handle image URLs (full E-PD flow with MultimodalEncoder)
elif
image_urls
:
if
self
.
encode_client
:
ep_disaggregated_params
=
await
self
.
remote_encode_full_epd
(
request
)
logging
.
info
(
f
"PrefillHandler: image_urls=
{
image_urls
}
"
)
result
=
await
fetch_embeddings_from_encoder
(
image_urls
,
request
,
self
.
encode_client
,
self
.
_encoder_cache
,
)
if
isinstance
(
result
,
list
):
# Cache path: got List[torch.Tensor]
embeddings_tensor
=
result
else
:
# No-cache path: got DisaggregatedParams
ep_disaggregated_params
=
result
# Normal flow: Generate the prefill response locally with embeddings
response_count
=
0
...
...
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_embedding_fetcher.py
0 → 100644
View file @
a78a4260
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for fetch_embeddings_from_encoder."""
from
typing
import
Any
from
unittest.mock
import
AsyncMock
,
patch
import
pytest
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.multimodal.hasher
import
MultimodalHasher
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
gpu_0
,
]
def
create_mock_encode_client
(
embeddings
:
list
[
torch
.
Tensor
],
processed_prompt
:
str
=
"prompt"
,
prompt_token_ids
:
list
[
int
]
|
None
=
None
,
)
->
AsyncMock
:
"""Create mock encode client that returns embeddings via CUDA IPC handles."""
class
MockResponse
:
def
data
(
self
):
return
{
"ep_disaggregated_params"
:
{
"multimodal_embedding_handles"
:
[
f
"h
{
i
}
"
for
i
in
range
(
len
(
embeddings
))
]
},
"processed_prompt"
:
processed_prompt
,
"prompt_token_ids"
:
prompt_token_ids
or
[
1
,
2
,
3
],
}
async
def
mock_round_robin
(
req
:
dict
[
str
,
Any
])
->
Any
:
async
def
gen
():
yield
MockResponse
()
return
gen
()
client
=
AsyncMock
()
client
.
round_robin
=
mock_round_robin
return
client
@
pytest
.
fixture
def
encoder_cache
()
->
EncoderCacheManager
:
"""Create encoder cache with 10MB capacity."""
return
EncoderCacheManager
(
capacity_bytes
=
10
*
1024
*
1024
)
class
TestFetchEmbeddingsFromEncoder
:
"""Tests for fetch_embeddings_from_encoder function."""
@
pytest
.
mark
.
asyncio
async
def
test_partial_cache_no_metadata_update
(
self
,
encoder_cache
):
"""Cache path: request NOT updated with EPD metadata."""
url1
,
url2
=
"http://example.com/img1.jpg"
,
"http://example.com/img2.jpg"
embedding1
,
embedding2
=
torch
.
ones
(
10
,
256
),
torch
.
ones
(
10
,
256
)
*
2
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url1
.
encode
()),
embedding1
)
request
:
dict
[
str
,
Any
]
=
{
"messages"
:
[]}
mock_client
=
create_mock_encode_client
([
embedding2
])
with
patch
(
"dynamo.trtllm.multimodal.embedding_fetcher.extract_embeddings_from_handles"
,
AsyncMock
(
return_value
=
[
embedding2
]),
):
result
=
await
fetch_embeddings_from_encoder
(
[
url1
,
url2
],
request
,
mock_client
,
encoder_cache
)
assert
len
(
result
)
==
2
assert
"_epd_processed_prompt"
not
in
request
@
pytest
.
mark
.
asyncio
async
def
test_all_cached_no_request_sent
(
self
,
encoder_cache
):
"""All cached: no encode request sent."""
url1
,
url2
=
"http://example.com/img1.jpg"
,
"http://example.com/img2.jpg"
embedding1
,
embedding2
=
torch
.
ones
(
10
,
256
),
torch
.
ones
(
10
,
256
)
*
2
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url1
.
encode
()),
embedding1
)
encoder_cache
.
set
(
MultimodalHasher
.
hash_bytes
(
url2
.
encode
()),
embedding2
)
async
def
should_not_call
(
req
:
dict
[
str
,
Any
])
->
None
:
raise
AssertionError
(
"Should not be called"
)
mock_client
=
AsyncMock
()
mock_client
.
round_robin
=
should_not_call
result
=
await
fetch_embeddings_from_encoder
(
[
url1
,
url2
],
{
"messages"
:
[]},
mock_client
,
encoder_cache
)
assert
len
(
result
)
==
2
assert
torch
.
equal
(
result
[
0
],
embedding1
)
@
pytest
.
mark
.
asyncio
async
def
test_no_cache_returns_disaggregated_params
(
self
):
"""No cache: returns DisaggregatedParams directly, request updated with metadata."""
request
:
dict
[
str
,
Any
]
=
{
"messages"
:
[]}
# Pass one embedding so mock generates one handle (DisaggregatedParams requires non-empty handles)
mock_client
=
create_mock_encode_client
(
[
torch
.
ones
(
10
,
256
)],
processed_prompt
=
"test <image>"
,
prompt_token_ids
=
[
10
,
20
],
)
result
=
await
fetch_embeddings_from_encoder
(
[
"http://example.com/img.jpg"
],
request
,
mock_client
,
encoder_cache
=
None
)
assert
isinstance
(
result
,
DisaggregatedParams
)
assert
request
[
"_epd_processed_prompt"
]
==
"test <image>"
assert
request
[
"_epd_prompt_token_ids"
]
==
[
10
,
20
]
@
pytest
.
mark
.
asyncio
async
def
test_empty_urls_raises_error
(
self
,
encoder_cache
):
"""Empty image_urls raises ValueError."""
mock_client
=
AsyncMock
()
with
pytest
.
raises
(
ValueError
,
match
=
"image_urls must not be empty"
):
await
fetch_embeddings_from_encoder
([],
{},
mock_client
,
encoder_cache
)
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
View file @
a78a4260
...
...
@@ -3,15 +3,19 @@
"""Unit tests for PrefillHandler."""
from
unittest.mock
import
MagicMock
from
typing
import
Any
from
unittest.mock
import
AsyncMock
,
MagicMock
,
patch
import
pytest
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.trtllm.request_handlers.handlers
import
PrefillHandler
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
gpu_0
,
]
...
...
@@ -29,10 +33,46 @@ def mock_encoder_cache():
cache
=
MagicMock
()
cache
.
get
=
MagicMock
(
return_value
=
None
)
cache
.
set
=
MagicMock
(
return_value
=
True
)
cache
.
stats
=
{
"hits"
:
0
,
"misses"
:
0
,
"entries"
:
0
}
return
cache
@
pytest
.
fixture
def
mock_context
():
"""Create a mock Context."""
ctx
=
MagicMock
()
ctx
.
id
=
MagicMock
(
return_value
=
"test-id"
)
ctx
.
is_stopped
=
MagicMock
(
return_value
=
False
)
ctx
.
is_killed
=
MagicMock
(
return_value
=
False
)
return
ctx
@
pytest
.
fixture
def
image_request
()
->
dict
[
str
,
Any
]:
"""Create a request with one image URL."""
return
{
"messages"
:
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"http://example.com/image.jpg"
},
},
],
}
]
}
def
setup_multimodal_config
(
mock_config
):
"""Configure mock_config for multimodal requests."""
mock_config
.
multimodal_processor
=
MagicMock
()
mock_config
.
multimodal_processor
.
extract_prompt_and_media
=
MagicMock
(
return_value
=
(
"text"
,
[
"http://example.com/image.jpg"
],
[])
)
mock_config
.
encode_client
=
MagicMock
()
class
TestPrefillHandlerInit
:
"""Tests for PrefillHandler initialization."""
...
...
@@ -42,3 +82,63 @@ class TestPrefillHandlerInit:
assert
handler
.
engine
==
mock_config
.
engine
assert
handler
.
_encoder_cache
==
mock_encoder_cache
class
TestPrefillHandlerGenerate
:
"""Tests for PrefillHandler.generate method."""
@
pytest
.
mark
.
asyncio
async
def
test_embeddings_passed_to_generate_locally
(
self
,
mock_config
,
mock_encoder_cache
,
mock_context
,
image_request
):
"""Test embeddings from fetch_embeddings_from_encoder passed to generate_locally."""
setup_multimodal_config
(
mock_config
)
handler
=
PrefillHandler
(
mock_config
,
encoder_cache
=
mock_encoder_cache
)
expected_embeddings
=
[
torch
.
randn
(
10
,
256
)]
captured_embeddings
=
None
async
def
mock_generate_locally
(
request
,
context
,
embeddings
,
ep_params
):
nonlocal
captured_embeddings
captured_embeddings
=
embeddings
yield
{
"result"
:
"mock"
}
with
patch
(
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder"
,
new_callable
=
AsyncMock
,
return_value
=
expected_embeddings
,
)
as
mock_fetch
:
with
patch
.
object
(
handler
,
"generate_locally"
,
mock_generate_locally
):
async
for
_
in
handler
.
generate
(
image_request
,
mock_context
):
pass
mock_fetch
.
assert_called_once
()
assert
captured_embeddings
is
expected_embeddings
@
pytest
.
mark
.
asyncio
async
def
test_disaggregated_params_passed_to_generate_locally
(
self
,
mock_config
,
mock_context
,
image_request
):
"""Test DisaggregatedParams from fetch_embeddings_from_encoder passed to generate_locally."""
setup_multimodal_config
(
mock_config
)
handler
=
PrefillHandler
(
mock_config
,
encoder_cache
=
None
)
expected_params
=
DisaggregatedParams
(
request_type
=
"context_only"
)
captured_ep_params
=
None
async
def
mock_generate_locally
(
request
,
context
,
embeddings
,
ep_params
):
nonlocal
captured_ep_params
captured_ep_params
=
ep_params
yield
{
"result"
:
"mock"
}
with
patch
(
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder"
,
new_callable
=
AsyncMock
,
return_value
=
expected_params
,
)
as
mock_fetch
:
with
patch
.
object
(
handler
,
"generate_locally"
,
mock_generate_locally
):
async
for
_
in
handler
.
generate
(
image_request
,
mock_context
):
pass
mock_fetch
.
assert_called_once
()
assert
captured_ep_params
is
expected_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment