Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
00ea11ff
Unverified
Commit
00ea11ff
authored
Feb 06, 2026
by
Qi Wang
Committed by
GitHub
Feb 06, 2026
Browse files
feat: EC E/PD workflow in TRT-LLM (#5815)
parent
410691dc
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
344 additions
and
133 deletions
+344
-133
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
+1
-1
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
.../src/dynamo/trtllm/request_handlers/aggregated_handler.py
+63
-0
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+7
-20
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_aggregated_handler.py
.../tests/request_handlers/test_trtllm_aggregated_handler.py
+67
-0
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
...llm/tests/request_handlers/test_trtllm_prefill_handler.py
+37
-104
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_request_handler_factory.py
...s/request_handlers/test_trtllm_request_handler_factory.py
+8
-8
components/src/dynamo/trtllm/tests/request_handlers/utils.py
components/src/dynamo/trtllm/tests/request_handlers/utils.py
+73
-0
examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/agg.yaml
...s/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/agg.yaml
+29
-0
examples/backends/trtllm/launch/e_pd_disagg.sh
examples/backends/trtllm/launch/e_pd_disagg.sh
+59
-0
No files found.
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
View file @
00ea11ff
...
@@ -27,7 +27,7 @@ async def extract_embeddings_from_handles(
...
@@ -27,7 +27,7 @@ async def extract_embeddings_from_handles(
properly.
properly.
Args:
Args:
handles: List of CUDA IPC handle dictionaries from encoder response
handles: List of CUDA IPC handle dictionaries from encoder response
.
Returns:
Returns:
List of embedding tensors on CPU.
List of embedding tensors on CPU.
...
...
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
0 → 100644
View file @
00ea11ff
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""
import
logging
from
typing
import
Optional
from
dynamo._core
import
Context
from
dynamo.common.memory.encoder_cache_manager
import
EncoderCacheManager
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.request_handlers.handler_base
import
(
HandlerBase
,
RequestHandlerConfig
,
)
class AggregatedHandler(HandlerBase):
    """
    Handler for aggregated mode (prefill + decode in single worker).

    Supports optional encoder disaggregation (E_PD flow): when both
    ``multimodal_processor`` and ``encode_client`` are set on the handler and
    the request carries image URLs, embeddings are fetched through
    ``fetch_embeddings_from_encoder`` before local generation. The encoder
    cache is optional and is simply forwarded to that fetch call.
    """

    def __init__(
        self,
        config: RequestHandlerConfig,
        encoder_cache: Optional[EncoderCacheManager] = None,
    ):
        """
        Args:
            config: Handler configuration, passed unchanged to ``HandlerBase``.
            encoder_cache: Optional cache forwarded to the embedding fetcher;
                ``None`` means no cache is supplied.
        """
        super().__init__(config)
        self._encoder_cache = encoder_cache

    async def generate(self, request: dict, context: Context):
        """Generate response, optionally using remote encoder for multimodal.

        Args:
            request: Request payload. Messages are read from
                ``request["extra_args"]["messages"]`` when present, otherwise
                from ``request["messages"]``.
            context: Request context; only ``context.id()`` is used here, the
                object itself is forwarded to ``generate_locally``.

        Yields:
            Every item produced by ``self.generate_locally``.
        """
        logging.debug("AggregatedHandler Request ID: %s", context.id())

        embeddings = None
        ep_disaggregated_params = None

        # NOTE(review): multimodal_processor / encode_client are presumably
        # populated by HandlerBase from the config — confirm there.
        if self.multimodal_processor and self.encode_client:
            # "extra_args" may be present but explicitly None; treat that the
            # same as a missing key instead of failing on None.get(...).
            extra_args = request.get("extra_args") or {}
            messages = extra_args.get("messages", request.get("messages", []))
            _, image_urls, _ = self.multimodal_processor.extract_prompt_and_media(
                messages
            )

            if image_urls:
                logging.info("AggregatedHandler: image_urls=%s", image_urls)
                result = await fetch_embeddings_from_encoder(
                    image_urls,
                    request,
                    self.encode_client,
                    self._encoder_cache,
                )
                # A list result is forwarded as embeddings; anything else is
                # forwarded as the encoder/prefill disaggregated params.
                if isinstance(result, list):
                    embeddings = result
                else:
                    ep_disaggregated_params = result

        async for res in self.generate_locally(
            request, context, embeddings, ep_disaggregated_params
        ):
            yield res
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
00ea11ff
...
@@ -9,6 +9,7 @@ from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
...
@@ -9,6 +9,7 @@ from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.trtllm.encode_helper
import
EncodeHelper
from
dynamo.trtllm.encode_helper
import
EncodeHelper
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.multimodal.embedding_fetcher
import
fetch_embeddings_from_encoder
from
dynamo.trtllm.request_handlers.aggregated_handler
import
AggregatedHandler
from
dynamo.trtllm.request_handlers.handler_base
import
(
from
dynamo.trtllm.request_handlers.handler_base
import
(
HandlerBase
,
HandlerBase
,
RequestHandlerConfig
,
RequestHandlerConfig
,
...
@@ -31,13 +32,14 @@ class RequestHandlerFactory:
...
@@ -31,13 +32,14 @@ class RequestHandlerFactory:
raise
ValueError
(
raise
ValueError
(
f
"Invalid disaggregation_mode '
{
config
.
disaggregation_mode
.
value
}
'"
f
"Invalid disaggregation_mode '
{
config
.
disaggregation_mode
.
value
}
'"
)
)
encoder_cache
=
None
if
config
.
encoder_cache_capacity_gb
>
0
:
capacity_bytes
=
int
(
config
.
encoder_cache_capacity_gb
*
1024
**
3
)
encoder_cache
=
EncoderCacheManager
(
capacity_bytes
)
if
config
.
disaggregation_mode
.
value
==
"prefill"
:
if
config
.
disaggregation_mode
.
value
==
"prefill"
:
encoder_cache
=
None
if
config
.
encoder_cache_capacity_gb
>
0
:
# Create encoder cache for prefill handler
capacity_bytes
=
int
(
config
.
encoder_cache_capacity_gb
*
1024
**
3
)
encoder_cache
=
EncoderCacheManager
(
capacity_bytes
)
return
PrefillHandler
(
config
,
encoder_cache
=
encoder_cache
)
return
PrefillHandler
(
config
,
encoder_cache
=
encoder_cache
)
if
config
.
disaggregation_mode
.
value
==
"prefill_and_decode"
:
return
AggregatedHandler
(
config
,
encoder_cache
=
encoder_cache
)
return
self
.
handlers
[
config
.
disaggregation_mode
.
value
](
config
)
return
self
.
handlers
[
config
.
disaggregation_mode
.
value
](
config
)
...
@@ -45,21 +47,6 @@ def get_request_handler(config: RequestHandlerConfig) -> HandlerBase:
...
@@ -45,21 +47,6 @@ def get_request_handler(config: RequestHandlerConfig) -> HandlerBase:
return
RequestHandlerFactory
().
get_request_handler
(
config
)
return
RequestHandlerFactory
().
get_request_handler
(
config
)
class
AggregatedHandler
(
HandlerBase
):
"""
Handler for the aggregated mode.
"""
def
__init__
(
self
,
config
:
RequestHandlerConfig
):
super
().
__init__
(
config
)
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
logging
.
debug
(
f
"New Request ID:
{
context
.
id
()
}
"
)
# Implement all steps locally.
async
for
res
in
self
.
generate_locally
(
request
,
context
):
yield
res
class
EncodeHandler
(
HandlerBase
):
class
EncodeHandler
(
HandlerBase
):
"""
"""
Handler for the encode mode.
Handler for the encode mode.
...
...
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_aggregated_handler.py
0 → 100644
View file @
00ea11ff
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for AggregatedHandler."""
import
pytest
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.trtllm.request_handlers.aggregated_handler
import
AggregatedHandler
from
dynamo.trtllm.tests.request_handlers.utils
import
(
create_mock_encoder_cache
,
run_generate_with_mock_fetch
,
setup_multimodal_config
,
)
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
# Markers applied to every test collected from this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.trtllm,
    pytest.mark.gpu_0,
]

# Dotted path handed to the shared test helper as the patch target for
# fetch_embeddings_from_encoder inside the aggregated handler module.
FETCH_PATCH_PATH = "dynamo.trtllm.request_handlers.aggregated_handler.fetch_embeddings_from_encoder"
class TestAggregatedHandlerGenerate:
    """Tests for AggregatedHandler.generate method."""

    @staticmethod
    def _build_handler(encoder_cache):
        """Create an AggregatedHandler whose mocked processor reports one image URL."""
        config = create_mock_request_handler_config(
            disaggregation_mode="prefill_and_decode"
        )
        setup_multimodal_config(config, ["http://example.com/image.jpg"])
        return AggregatedHandler(config, encoder_cache=encoder_cache)

    @pytest.mark.asyncio
    async def test_embeddings_passed_to_generate_locally(self):
        """A list returned by the mocked fetch reaches generate_locally as embeddings."""
        handler = self._build_handler(create_mock_encoder_cache())
        expected_embeddings = [torch.randn(10, 256)]

        captured = await run_generate_with_mock_fetch(
            handler, FETCH_PATCH_PATH, expected_embeddings
        )
        embeddings, ep_params = captured

        assert embeddings is expected_embeddings
        assert ep_params is None

    @pytest.mark.asyncio
    async def test_disaggregated_params_passed_to_generate_locally(self):
        """A DisaggregatedParams returned by the mocked fetch reaches generate_locally as ep_params."""
        handler = self._build_handler(None)
        expected_params = DisaggregatedParams(request_type="context_only")

        captured = await run_generate_with_mock_fetch(
            handler, FETCH_PATCH_PATH, expected_params
        )
        embeddings, ep_params = captured

        assert embeddings is None
        assert ep_params is expected_params
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_prefill_handler.py
View file @
00ea11ff
...
@@ -3,14 +3,16 @@
...
@@ -3,14 +3,16 @@
"""Unit tests for PrefillHandler."""
"""Unit tests for PrefillHandler."""
from
typing
import
Any
from
unittest.mock
import
AsyncMock
,
MagicMock
,
patch
import
pytest
import
pytest
import
torch
import
torch
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
tensorrt_llm.llmapi
import
DisaggregatedParams
from
dynamo.trtllm.request_handlers.handlers
import
PrefillHandler
from
dynamo.trtllm.request_handlers.handlers
import
PrefillHandler
from
dynamo.trtllm.tests.request_handlers.utils
import
(
create_mock_encoder_cache
,
run_generate_with_mock_fetch
,
setup_multimodal_config
,
)
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
from
dynamo.trtllm.tests.utils
import
create_mock_request_handler_config
pytestmark
=
[
pytestmark
=
[
...
@@ -20,125 +22,56 @@ pytestmark = [
...
@@ -20,125 +22,56 @@ pytestmark = [
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
gpu_0
,
]
]
FETCH_PATCH_PATH
=
(
@
pytest
.
fixture
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder"
def
mock_config
():
)
"""Create a mock RequestHandlerConfig."""
return
create_mock_request_handler_config
(
disaggregation_mode
=
"prefill"
)
@
pytest
.
fixture
def
mock_encoder_cache
():
"""Create a mock EncoderCacheManager."""
cache
=
MagicMock
()
cache
.
get
=
MagicMock
(
return_value
=
None
)
cache
.
set
=
MagicMock
(
return_value
=
True
)
return
cache
@
pytest
.
fixture
def
mock_context
():
"""Create a mock Context."""
ctx
=
MagicMock
()
ctx
.
id
=
MagicMock
(
return_value
=
"test-id"
)
ctx
.
is_stopped
=
MagicMock
(
return_value
=
False
)
ctx
.
is_killed
=
MagicMock
(
return_value
=
False
)
return
ctx
@
pytest
.
fixture
def
image_request
()
->
dict
[
str
,
Any
]:
"""Create a request with one image URL."""
return
{
"messages"
:
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"http://example.com/image.jpg"
},
},
],
}
]
}
def
setup_multimodal_config
(
mock_config
):
"""Configure mock_config for multimodal requests."""
mock_config
.
multimodal_processor
=
MagicMock
()
mock_config
.
multimodal_processor
.
extract_prompt_and_media
=
MagicMock
(
return_value
=
(
"text"
,
[
"http://example.com/image.jpg"
],
[])
)
mock_config
.
encode_client
=
MagicMock
()
class
TestPrefillHandlerInit
:
class
TestPrefillHandlerInit
:
"""Tests for PrefillHandler initialization."""
"""Tests for PrefillHandler initialization."""
def
test_init_with_encoder_cache
(
self
,
mock_config
,
mock_encoder_cache
):
def
test_init_with_encoder_cache
(
self
):
"""Test PrefillHandler can be initialized with encoder_cache."""
"""Test PrefillHandler can be initialized with encoder_cache."""
handler
=
PrefillHandler
(
mock_config
,
encoder_cache
=
mock_encoder_cache
)
config
=
create_mock_request_handler_config
(
disaggregation_mode
=
"prefill"
)
cache
=
create_mock_encoder_cache
()
handler
=
PrefillHandler
(
config
,
encoder_cache
=
cache
)
assert
handler
.
engine
==
mock_
config
.
engine
assert
handler
.
engine
==
config
.
engine
assert
handler
.
_encoder_cache
==
mock_encoder_
cache
assert
handler
.
_encoder_cache
==
cache
class
TestPrefillHandlerGenerate
:
class
TestPrefillHandlerGenerate
:
"""Tests for PrefillHandler.generate method."""
"""Tests for PrefillHandler.generate method."""
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_embeddings_passed_to_generate_locally
(
async
def
test_embeddings_passed_to_generate_locally
(
self
):
self
,
mock_config
,
mock_encoder_cache
,
mock_context
,
image_request
"""Cache path: List[Tensor] passed as embeddings."""
):
config
=
create_mock_request_handler_config
(
disaggregation_mode
=
"prefill"
)
"""Test embeddings from fetch_embeddings_from_encoder passed to generate_locally."""
setup_multimodal_config
(
config
,
[
"http://example.com/image.jpg"
])
setup_multimodal_config
(
mock_config
)
handler
=
PrefillHandler
(
config
,
encoder_cache
=
create_mock_encoder_cache
())
handler
=
PrefillHandler
(
mock_config
,
encoder_cache
=
mock_encoder_cache
)
expected_embeddings
=
[
torch
.
randn
(
10
,
256
)]
expected_embeddings
=
[
torch
.
randn
(
10
,
256
)]
captured_embeddings
=
None
async
def
mock_generate_locally
(
request
,
context
,
embeddings
,
ep_params
):
embeddings
,
ep_params
=
await
run_generate_with_mock_fetch
(
nonlocal
captured_embeddings
handler
,
FETCH_PATCH_PATH
,
expected_embeddings
captured_embeddings
=
embeddings
)
yield
{
"result"
:
"mock"
}
with
patch
(
assert
embeddings
is
expected_embeddings
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder"
,
assert
ep_params
is
None
new_callable
=
AsyncMock
,
return_value
=
expected_embeddings
,
)
as
mock_fetch
:
with
patch
.
object
(
handler
,
"generate_locally"
,
mock_generate_locally
):
async
for
_
in
handler
.
generate
(
image_request
,
mock_context
):
pass
mock_fetch
.
assert_called_once
()
assert
captured_embeddings
is
expected_embeddings
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_disaggregated_params_passed_to_generate_locally
(
async
def
test_disaggregated_params_passed_to_generate_locally
(
self
):
self
,
mock_config
,
mock_context
,
image_request
"""No-cache path: DisaggregatedParams passed as ep_params."""
):
config
=
create_mock_request_handler_config
(
disaggregation_mode
=
"prefill"
)
"""Test DisaggregatedParams from fetch_embeddings_from_encoder passed to generate_locally."""
setup_multimodal_config
(
config
,
[
"http://example.com/image.jpg"
])
setup_multimodal_config
(
mock_config
)
handler
=
PrefillHandler
(
config
,
encoder_cache
=
None
)
handler
=
PrefillHandler
(
mock_config
,
encoder_cache
=
None
)
expected_params
=
DisaggregatedParams
(
request_type
=
"context_only"
)
expected_params
=
DisaggregatedParams
(
request_type
=
"context_only"
)
captured_ep_params
=
None
embeddings
,
ep_params
=
await
run_generate_with_mock_fetch
(
async
def
mock_generate_locally
(
request
,
context
,
embeddings
,
ep_params
):
handler
,
FETCH_PATCH_PATH
,
expected_params
nonlocal
captured_ep_params
)
captured_ep_params
=
ep_params
yield
{
"result"
:
"mock"
}
assert
embeddings
is
None
assert
ep_params
is
expected_params
with
patch
(
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder"
,
new_callable
=
AsyncMock
,
return_value
=
expected_params
,
)
as
mock_fetch
:
with
patch
.
object
(
handler
,
"generate_locally"
,
mock_generate_locally
):
async
for
_
in
handler
.
generate
(
image_request
,
mock_context
):
pass
mock_fetch
.
assert_called_once
()
assert
captured_ep_params
is
expected_params
components/src/dynamo/trtllm/tests/request_handlers/test_trtllm_request_handler_factory.py
View file @
00ea11ff
...
@@ -29,6 +29,14 @@ def mock_config():
...
@@ -29,6 +29,14 @@ def mock_config():
class
TestRequestHandlerFactory
:
class
TestRequestHandlerFactory
:
"""Tests for RequestHandlerFactory."""
"""Tests for RequestHandlerFactory."""
def
test_invalid_mode_raises
(
self
,
mock_config
):
"""Test factory raises ValueError for invalid disaggregation_mode."""
mock_config
.
disaggregation_mode
.
value
=
"invalid_mode"
factory
=
RequestHandlerFactory
()
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid disaggregation_mode"
):
factory
.
get_request_handler
(
mock_config
)
def
test_creates_aggregated_handler
(
self
,
mock_config
):
def
test_creates_aggregated_handler
(
self
,
mock_config
):
"""Test factory creates AggregatedHandler for prefill_and_decode mode."""
"""Test factory creates AggregatedHandler for prefill_and_decode mode."""
factory
=
RequestHandlerFactory
()
factory
=
RequestHandlerFactory
()
...
@@ -44,14 +52,6 @@ class TestRequestHandlerFactory:
...
@@ -44,14 +52,6 @@ class TestRequestHandlerFactory:
assert
isinstance
(
handler
,
PrefillHandler
)
assert
isinstance
(
handler
,
PrefillHandler
)
def
test_invalid_mode_raises
(
self
,
mock_config
):
"""Test factory raises ValueError for invalid disaggregation_mode."""
mock_config
.
disaggregation_mode
.
value
=
"invalid_mode"
factory
=
RequestHandlerFactory
()
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid disaggregation_mode"
):
factory
.
get_request_handler
(
mock_config
)
def
test_prefill_handler_with_encoder_cache
(
self
):
def
test_prefill_handler_with_encoder_cache
(
self
):
"""Test factory creates PrefillHandler with EncoderCacheManager when capacity > 0."""
"""Test factory creates PrefillHandler with EncoderCacheManager when capacity > 0."""
mock_config
=
create_mock_request_handler_config
(
mock_config
=
create_mock_request_handler_config
(
...
...
components/src/dynamo/trtllm/tests/request_handlers/utils.py
0 → 100644
View file @
00ea11ff
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared test utilities for request handler tests."""
from
typing
import
Any
,
List
,
Tuple
from
unittest.mock
import
AsyncMock
,
MagicMock
,
patch
def create_mock_encoder_cache() -> MagicMock:
    """Create mock EncoderCacheManager.

    Returns:
        A ``MagicMock`` whose ``get`` returns ``None`` and whose ``set``
        returns ``True`` for any arguments.
    """
    stub = MagicMock()
    for method_name, result in (("get", None), ("set", True)):
        setattr(stub, method_name, MagicMock(return_value=result))
    return stub
def create_mock_context(request_id: str = "test-id") -> MagicMock:
    """Create mock Context.

    Args:
        request_id: Value returned by the mock's ``id()`` method.

    Returns:
        A ``MagicMock`` whose ``id()`` returns ``request_id`` and whose
        ``is_stopped()`` and ``is_killed()`` both return ``False``.
    """
    canned_results = {
        "id": request_id,
        "is_stopped": False,
        "is_killed": False,
    }
    fake_context = MagicMock()
    for method_name, result in canned_results.items():
        setattr(fake_context, method_name, MagicMock(return_value=result))
    return fake_context
def setup_multimodal_config(config: MagicMock, image_urls: List[str]) -> None:
    """Configure multimodal_processor and encode_client on config.

    Args:
        config: Object to mutate in place; both attributes are overwritten.
        image_urls: List returned (same object, not a copy) as the second
            element of ``extract_prompt_and_media``'s result.
    """
    processor = MagicMock()
    # extract_prompt_and_media always yields ("text", image_urls, []),
    # regardless of the messages it is called with.
    processor.extract_prompt_and_media = MagicMock(
        return_value=("text", image_urls, [])
    )
    config.multimodal_processor = processor
    config.encode_client = MagicMock()
async def run_generate_with_mock_fetch(
    handler: Any,
    fetch_patch_path: str,
    mock_return_value: Any,
) -> Tuple[Any, Any]:
    """
    Run handler.generate() with mocked fetch_embeddings_from_encoder.

    The handler's ``generate_locally`` is replaced for the duration of the call
    by a recorder that stores the ``embeddings`` and ``ep_params`` it receives
    and yields a single placeholder item. The request passed to ``generate`` is
    ``{"messages": []}``.

    Args:
        handler: Handler instance (PrefillHandler or AggregatedHandler)
        fetch_patch_path: Full path to patch fetch_embeddings_from_encoder
        mock_return_value: Value to return from mocked fetch

    Returns:
        Tuple of (captured_embeddings, captured_ep_params); both stay None if
        generate_locally is never invoked.

    Raises:
        AssertionError: If the mocked fetch was not called exactly once.
    """
    seen: dict[str, Any] = {"embeddings": None, "ep_params": None}

    async def mock_generate_locally(request, context, embeddings, ep_params):
        seen["embeddings"] = embeddings
        seen["ep_params"] = ep_params
        yield {"result": "mock"}

    request: dict[str, Any] = {"messages": []}

    with patch(
        fetch_patch_path,
        new_callable=AsyncMock,
        return_value=mock_return_value,
    ) as mock_fetch, patch.object(handler, "generate_locally", mock_generate_locally):
        # Drain the generator; only the recorded arguments matter.
        async for _ in handler.generate(request, create_mock_context()):
            pass

    mock_fetch.assert_called_once()
    return seen["embeddings"], seen["ep_params"]
examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/agg.yaml
0 → 100644
View file @
00ea11ff
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Engine arguments for the llava-v1.6-mistral-7b-hf aggregated (agg) config.
# NOTE(review): nesting below was reconstructed from a flattened view of the
# file — confirm the two nested sections against the original.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true

kv_cache_config:
  free_gpu_memory_fraction: 0.60
  enable_block_reuse: false

cache_transceiver_config:
  backend: DEFAULT
examples/backends/trtllm/launch/e_pd_disagg.sh
0 → 100755
View file @
00ea11ff
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# 1 Encode + 1 PD worker for llava-v1.6-mistral-7b-hf
# GPU 0: Encode (vision encoder)
# GPU 1: PD worker (prefill + decode, TP=1)
# Environment variables with defaults
# Every setting below keeps a value already present in the environment and
# only falls back to the default shown when it is unset or empty.

# Locations and model identity
export DYNAMO_HOME="${DYNAMO_HOME:-/workspace}"
export MODEL_PATH="${MODEL_PATH:-llava-hf/llava-v1.6-mistral-7b-hf}"
export SERVED_MODEL_NAME="${SERVED_MODEL_NAME:-llava-v1.6-mistral-7b-hf}"

# Engine config files for the encode worker and the PD worker
export ENCODE_ENGINE_ARGS="${ENCODE_ENGINE_ARGS:-$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/encode.yaml}"
export PD_ENGINE_ARGS="${PD_ENGINE_ARGS:-$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/agg.yaml}"

# Encode worker placement and the endpoint passed to the PD worker
export ENCODE_CUDA_VISIBLE_DEVICES="${ENCODE_CUDA_VISIBLE_DEVICES:-0}"
export ENCODE_ENDPOINT="${ENCODE_ENDPOINT:-dyn://dynamo.tensorrt_llm_encode.generate}"

# Multimodal input handling
export MODALITY="${MODALITY:-multimodal}"
export ALLOWED_LOCAL_MEDIA_PATH="${ALLOWED_LOCAL_MEDIA_PATH:-/tmp}"
export MAX_FILE_SIZE_MB="${MAX_FILE_SIZE_MB:-50}"

# Encoder cache size and chat template for the PD worker
export DYN_ENCODER_CACHE_CAPACITY_GB="${DYN_ENCODER_CACHE_CAPACITY_GB:-4}"
export CUSTOM_TEMPLATE="${CUSTOM_TEMPLATE:-$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja}"
# Setup cleanup trap
# Stop the background processes recorded in DYNAMO_PID, ENCODE_PID and
# PD_PID_1, then wait for them to exit. Safe to call more than once and
# before any of those variables has been set.
cleanup() {
    # When a signal handler returns the script still runs its EXIT handler,
    # so without this guard the body would execute twice.
    if [[ -n "${_CLEANUP_DONE:-}" ]]; then
        return
    fi
    _CLEANUP_DONE=1

    echo "Cleaning up background processes..."

    # Collect only the PIDs that were actually recorded; a bare `wait` with
    # no arguments would otherwise wait on every child of the shell.
    local pids=()
    local pid
    for pid in "${DYNAMO_PID:-}" "${ENCODE_PID:-}" "${PD_PID_1:-}"; do
        if [[ -n "$pid" ]]; then
            pids+=("$pid")
        fi
    done

    if (( ${#pids[@]} > 0 )); then
        # Best effort: processes may already have exited.
        kill "${pids[@]}" 2>/dev/null || true
        wait "${pids[@]}" 2>/dev/null || true
    fi

    echo "Cleanup complete."
}
trap cleanup EXIT INT TERM

# run frontend
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python3 -m dynamo.frontend &
DYNAMO_PID=$!

# run encode worker (vision encoder; GPU selected by ENCODE_CUDA_VISIBLE_DEVICES)
CUDA_VISIBLE_DEVICES="$ENCODE_CUDA_VISIBLE_DEVICES" python3 -m dynamo.trtllm \
  --model-path "$MODEL_PATH" \
  --served-model-name "$SERVED_MODEL_NAME" \
  --extra-engine-args "$ENCODE_ENGINE_ARGS" \
  --modality "$MODALITY" \
  --allowed-local-media-path "$ALLOWED_LOCAL_MEDIA_PATH" \
  --max-file-size-mb "$MAX_FILE_SIZE_MB" \
  --disaggregation-mode encode &
ENCODE_PID=$!

# run PD worker 1 (GPU 1 unless PD_CUDA_VISIBLE_DEVICES is set, mirroring
# the override already available for the encode worker)
CUDA_VISIBLE_DEVICES="${PD_CUDA_VISIBLE_DEVICES:-1}" python3 -m dynamo.trtllm \
  --model-path "$MODEL_PATH" \
  --served-model-name "$SERVED_MODEL_NAME" \
  --extra-engine-args "$PD_ENGINE_ARGS" \
  --modality "$MODALITY" \
  --custom-jinja-template "$CUSTOM_TEMPLATE" \
  --encode-endpoint "$ENCODE_ENDPOINT" \
  --dyn-encoder-cache-capacity-gb "$DYN_ENCODER_CACHE_CAPACITY_GB" &
PD_PID_1=$!

# Block on the frontend; the EXIT trap tears the workers down afterwards.
wait "$DYNAMO_PID"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment