Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b82b45a1
Unverified
Commit
b82b45a1
authored
Feb 03, 2026
by
Qi Wang
Committed by
GitHub
Feb 04, 2026
Browse files
feat: add EncoderCacheManager to TRT-LLM PrefillHandler (#5714)
parent
0268aea4
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
182 additions
and
25 deletions
+182
-25
components/src/dynamo/common/memory/encoder_cache_manager.py
components/src/dynamo/common/memory/encoder_cache_manager.py
+0
-3
components/src/dynamo/common/tests/memory/test_encoder_cache_manager.py
.../dynamo/common/tests/memory/test_encoder_cache_manager.py
+0
-21
components/src/dynamo/trtllm/main.py
components/src/dynamo/trtllm/main.py
+1
-0
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+1
-0
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+14
-1
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
...llm/tests/request_handlers/test_trtllm_prefill_handler.py
+44
-0
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_request_handler_factory.py
...s/request_handlers/test_trtllm_request_handler_factory.py
+77
-0
components/src/dynamo/trtllm/tests/utils.py
components/src/dynamo/trtllm/tests/utils.py
+36
-0
components/src/dynamo/trtllm/utils/trtllm_utils.py
components/src/dynamo/trtllm/utils/trtllm_utils.py
+9
-0
No files found.
components/src/dynamo/common/memory/encoder_cache_manager.py
View file @
b82b45a1
...
@@ -47,9 +47,6 @@ class EncoderCacheManager:
...
@@ -47,9 +47,6 @@ class EncoderCacheManager:
Args:
Args:
capacity_bytes: Maximum cache capacity in bytes.
capacity_bytes: Maximum cache capacity in bytes.
"""
"""
if
capacity_bytes
<=
0
:
raise
ValueError
(
"capacity_bytes must be positive"
)
self
.
_cache
:
OrderedDict
[
str
,
torch
.
Tensor
]
=
OrderedDict
()
self
.
_cache
:
OrderedDict
[
str
,
torch
.
Tensor
]
=
OrderedDict
()
self
.
_capacity_bytes
=
capacity_bytes
self
.
_capacity_bytes
=
capacity_bytes
self
.
_current_bytes
=
0
self
.
_current_bytes
=
0
...
...
components/src/dynamo/common/tests/memory/test_encoder_cache_manager.py
View file @
b82b45a1
...
@@ -9,27 +9,6 @@ import torch
...
@@ -9,27 +9,6 @@ import torch
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
class
TestEncoderCacheManagerInit
:
"""Tests for initialization."""
def
test_init_valid_capacity
(
self
):
"""Test initialization with valid capacity."""
cache
=
EncoderCacheManager
(
capacity_bytes
=
1024
)
assert
cache
.
stats
[
"capacity_bytes"
]
==
1024
assert
cache
.
stats
[
"current_bytes"
]
==
0
assert
cache
.
stats
[
"entries"
]
==
0
def
test_init_invalid_capacity_zero
(
self
):
"""Test initialization with zero capacity raises error."""
with
pytest
.
raises
(
ValueError
,
match
=
"capacity_bytes must be positive"
):
EncoderCacheManager
(
capacity_bytes
=
0
)
def
test_init_invalid_capacity_negative
(
self
):
"""Test initialization with negative capacity raises error."""
with
pytest
.
raises
(
ValueError
,
match
=
"capacity_bytes must be positive"
):
EncoderCacheManager
(
capacity_bytes
=-
100
)
class
TestEncoderCacheManagerBasicOperations
:
class
TestEncoderCacheManagerBasicOperations
:
"""Tests for basic get/set operations."""
"""Tests for basic get/set operations."""
...
...
components/src/dynamo/trtllm/main.py
View file @
b82b45a1
...
@@ -441,6 +441,7 @@ async def init(
...
@@ -441,6 +441,7 @@ async def init(
metrics_collector
=
metrics_collector
,
metrics_collector
=
metrics_collector
,
kv_block_size
=
config
.
kv_block_size
,
kv_block_size
=
config
.
kv_block_size
,
shutdown_event
=
shutdown_event
,
shutdown_event
=
shutdown_event
,
encoder_cache_capacity_gb
=
config
.
encoder_cache_capacity_gb
,
)
)
# Register the model with runtime config
# Register the model with runtime config
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
b82b45a1
...
@@ -68,6 +68,7 @@ class RequestHandlerConfig:
...
@@ -68,6 +68,7 @@ class RequestHandlerConfig:
metrics_collector
:
Optional
[
Any
]
=
None
# TensorRT-LLM MetricsCollector
metrics_collector
:
Optional
[
Any
]
=
None
# TensorRT-LLM MetricsCollector
kv_block_size
:
int
=
32
kv_block_size
:
int
=
32
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
encoder_cache_capacity_gb
:
float
=
0
# Encoder cache capacity in GB
class
HandlerBase
:
class
HandlerBase
:
...
...
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
b82b45a1
...
@@ -7,6 +7,7 @@ from typing import Optional
...
@@ -7,6 +7,7 @@ from typing import Optional
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo._core
import
Context
from
dynamo._core
import
Context
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.trtllm.encode_helper
import
EncodeHelper
from
dynamo.trtllm.encode_helper
import
EncodeHelper
from
dynamo.trtllm.request_handlers.handler_base
import
(
from
dynamo.trtllm.request_handlers.handler_base
import
(
...
@@ -31,6 +32,13 @@ class RequestHandlerFactory:
...
@@ -31,6 +32,13 @@ class RequestHandlerFactory:
raise
ValueError
(
raise
ValueError
(
f
"Invalid disaggregation_mode '
{
config
.
disaggregation_mode
.
value
}
'"
f
"Invalid disaggregation_mode '
{
config
.
disaggregation_mode
.
value
}
'"
)
)
if
config
.
disaggregation_mode
.
value
==
"prefill"
:
encoder_cache
=
None
if
config
.
encoder_cache_capacity_gb
>
0
:
# Create encoder cache for prefill handler
capacity_bytes
=
int
(
config
.
encoder_cache_capacity_gb
*
1024
**
3
)
encoder_cache
=
EncoderCacheManager
(
capacity_bytes
)
return
PrefillHandler
(
config
,
encoder_cache
=
encoder_cache
)
return
self
.
handlers
[
config
.
disaggregation_mode
.
value
](
config
)
return
self
.
handlers
[
config
.
disaggregation_mode
.
value
](
config
)
...
@@ -93,8 +101,13 @@ class PrefillHandler(HandlerBase):
...
@@ -93,8 +101,13 @@ class PrefillHandler(HandlerBase):
Handler for prefill-only workers in disaggregated serving.
Handler for prefill-only workers in disaggregated serving.
"""
"""
def
__init__
(
self
,
config
:
RequestHandlerConfig
):
def
__init__
(
self
,
config
:
RequestHandlerConfig
,
encoder_cache
:
Optional
[
EncoderCacheManager
]
=
None
,
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
_encoder_cache
=
encoder_cache
async
def
remote_encode_full_epd
(
self
,
request
:
dict
):
async
def
remote_encode_full_epd
(
self
,
request
:
dict
):
"""
"""
...
...
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
0 → 100644
View file @
b82b45a1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for PrefillHandler."""
from
unittest.mock
import
MagicMock
import
pytest
from
dynamo.trtllm.request_handlers.handlers
import
PrefillHandler
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
gpu_0
,
]
@
pytest
.
fixture
def
mock_config
():
"""Create a mock RequestHandlerConfig."""
return
create_mock_request_handler_config
(
disaggregation_mode
=
"prefill"
)
@
pytest
.
fixture
def
mock_encoder_cache
():
"""Create a mock EncoderCacheManager."""
cache
=
MagicMock
()
cache
.
get
=
MagicMock
(
return_value
=
None
)
cache
.
set
=
MagicMock
(
return_value
=
True
)
cache
.
stats
=
{
"hits"
:
0
,
"misses"
:
0
,
"entries"
:
0
}
return
cache
class
TestPrefillHandlerInit
:
"""Tests for PrefillHandler initialization."""
def
test_init_with_encoder_cache
(
self
,
mock_config
,
mock_encoder_cache
):
"""Test PrefillHandler can be initialized with encoder_cache."""
handler
=
PrefillHandler
(
mock_config
,
encoder_cache
=
mock_encoder_cache
)
assert
handler
.
engine
==
mock_config
.
engine
assert
handler
.
_encoder_cache
==
mock_encoder_cache
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_request_handler_factory.py
0 → 100644
View file @
b82b45a1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for RequestHandlerFactory."""
import
pytest
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.trtllm.request_handlers.handlers
import
(
AggregatedHandler
,
PrefillHandler
,
RequestHandlerFactory
,
)
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
gpu_0
,
]
@
pytest
.
fixture
def
mock_config
():
"""Create a mock RequestHandlerConfig."""
return
create_mock_request_handler_config
()
class
TestRequestHandlerFactory
:
"""Tests for RequestHandlerFactory."""
def
test_creates_aggregated_handler
(
self
,
mock_config
):
"""Test factory creates AggregatedHandler for prefill_and_decode mode."""
factory
=
RequestHandlerFactory
()
handler
=
factory
.
get_request_handler
(
mock_config
)
assert
isinstance
(
handler
,
AggregatedHandler
)
def
test_creates_prefill_handler
(
self
,
mock_config
):
"""Test factory creates PrefillHandler for prefill mode."""
mock_config
.
disaggregation_mode
.
value
=
"prefill"
factory
=
RequestHandlerFactory
()
handler
=
factory
.
get_request_handler
(
mock_config
)
assert
isinstance
(
handler
,
PrefillHandler
)
def
test_invalid_mode_raises
(
self
,
mock_config
):
"""Test factory raises ValueError for invalid disaggregation_mode."""
mock_config
.
disaggregation_mode
.
value
=
"invalid_mode"
factory
=
RequestHandlerFactory
()
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid disaggregation_mode"
):
factory
.
get_request_handler
(
mock_config
)
def
test_prefill_handler_with_encoder_cache
(
self
):
"""Test factory creates PrefillHandler with EncoderCacheManager when capacity > 0."""
mock_config
=
create_mock_request_handler_config
(
disaggregation_mode
=
"prefill"
,
encoder_cache_capacity_gb
=
1.0
,
)
factory
=
RequestHandlerFactory
()
handler
=
factory
.
get_request_handler
(
mock_config
)
assert
isinstance
(
handler
,
PrefillHandler
)
assert
isinstance
(
handler
.
_encoder_cache
,
EncoderCacheManager
)
def
test_prefill_handler_without_encoder_cache
(
self
):
"""Test factory creates PrefillHandler with no cache when capacity is 0."""
mock_config
=
create_mock_request_handler_config
(
disaggregation_mode
=
"prefill"
,
encoder_cache_capacity_gb
=
0
,
)
factory
=
RequestHandlerFactory
()
handler
=
factory
.
get_request_handler
(
mock_config
)
assert
isinstance
(
handler
,
PrefillHandler
)
assert
handler
.
_encoder_cache
is
None
components/src/dynamo/trtllm/tests/utils.py
0 → 100644
View file @
b82b45a1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared test utilities for dynamo.trtllm tests."""
from
unittest.mock
import
MagicMock
def
create_mock_request_handler_config
(
disaggregation_mode
:
str
=
"prefill_and_decode"
,
encoder_cache_capacity_gb
:
float
=
0
,
)
->
MagicMock
:
"""Create a mock RequestHandlerConfig for testing.
Args:
disaggregation_mode: The disaggregation mode value.
encoder_cache_capacity_gb: Encoder cache capacity in GB.
Returns:
MagicMock configured as a RequestHandlerConfig.
"""
config
=
MagicMock
()
config
.
disaggregation_mode
.
value
=
disaggregation_mode
config
.
engine
=
MagicMock
()
config
.
component
=
MagicMock
()
config
.
default_sampling_params
=
MagicMock
()
config
.
publisher
=
MagicMock
()
config
.
metrics_collector
=
None
config
.
encode_client
=
None
config
.
multimodal_processor
=
None
config
.
connector
=
None
config
.
runtime
=
None
config
.
kv_block_size
=
32
config
.
shutdown_event
=
None
config
.
encoder_cache_capacity_gb
=
encoder_cache_capacity_gb
return
config
components/src/dynamo/trtllm/utils/trtllm_utils.py
View file @
b82b45a1
...
@@ -54,6 +54,7 @@ class Config:
...
@@ -54,6 +54,7 @@ class Config:
self
.
modality
:
str
=
"text"
self
.
modality
:
str
=
"text"
self
.
allowed_local_media_path
:
str
=
""
self
.
allowed_local_media_path
:
str
=
""
self
.
max_file_size_mb
:
int
=
50
self
.
max_file_size_mb
:
int
=
50
self
.
encoder_cache_capacity_gb
:
float
=
0
self
.
reasoning_parser
:
Optional
[
str
]
=
None
self
.
reasoning_parser
:
Optional
[
str
]
=
None
self
.
tool_call_parser
:
Optional
[
str
]
=
None
self
.
tool_call_parser
:
Optional
[
str
]
=
None
self
.
dump_config_to
:
Optional
[
str
]
=
None
self
.
dump_config_to
:
Optional
[
str
]
=
None
...
@@ -92,6 +93,7 @@ class Config:
...
@@ -92,6 +93,7 @@ class Config:
f
"modality=
{
self
.
modality
}
, "
f
"modality=
{
self
.
modality
}
, "
f
"allowed_local_media_path=
{
self
.
allowed_local_media_path
}
, "
f
"allowed_local_media_path=
{
self
.
allowed_local_media_path
}
, "
f
"max_file_size_mb=
{
self
.
max_file_size_mb
}
, "
f
"max_file_size_mb=
{
self
.
max_file_size_mb
}
, "
f
"encoder_cache_capacity_gb=
{
self
.
encoder_cache_capacity_gb
}
, "
f
"reasoning_parser=
{
self
.
reasoning_parser
}
, "
f
"reasoning_parser=
{
self
.
reasoning_parser
}
, "
f
"tool_call_parser=
{
self
.
tool_call_parser
}
, "
f
"tool_call_parser=
{
self
.
tool_call_parser
}
, "
f
"dump_config_to=
{
self
.
dump_config_to
}
, "
f
"dump_config_to=
{
self
.
dump_config_to
}
, "
...
@@ -286,6 +288,12 @@ def cmd_line_args():
...
@@ -286,6 +288,12 @@ def cmd_line_args():
default
=
50
,
default
=
50
,
help
=
"Maximum size of downloadable embedding files/Image URLs. Default: 50MB"
,
help
=
"Maximum size of downloadable embedding files/Image URLs. Default: 50MB"
,
)
)
parser
.
add_argument
(
"--dyn-encoder-cache-capacity-gb"
,
type
=
float
,
default
=
0
,
help
=
"Capacity of the encoder cache in GB for multimodal embeddings. Default: 0"
,
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser
.
add_argument
(
parser
.
add_argument
(
"--dyn-tool-call-parser"
,
"--dyn-tool-call-parser"
,
...
@@ -384,6 +392,7 @@ def cmd_line_args():
...
@@ -384,6 +392,7 @@ def cmd_line_args():
config
.
encode_endpoint
=
args
.
encode_endpoint
config
.
encode_endpoint
=
args
.
encode_endpoint
config
.
allowed_local_media_path
=
args
.
allowed_local_media_path
config
.
allowed_local_media_path
=
args
.
allowed_local_media_path
config
.
max_file_size_mb
=
args
.
max_file_size_mb
config
.
max_file_size_mb
=
args
.
max_file_size_mb
config
.
encoder_cache_capacity_gb
=
args
.
dyn_encoder_cache_capacity_gb
config
.
tensor_parallel_size
=
args
.
tensor_parallel_size
config
.
tensor_parallel_size
=
args
.
tensor_parallel_size
if
args
.
pipeline_parallel_size
is
not
None
:
if
args
.
pipeline_parallel_size
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment