Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
5a67b246
Unverified
Commit
5a67b246
authored
Feb 18, 2026
by
jh-nv
Committed by
GitHub
Feb 18, 2026
Browse files
feat: Migrate trtllm configuration (#6297)
parent
9a93eb75
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
657 additions
and
635 deletions
+657
-635
components/src/dynamo/trtllm/args.py
components/src/dynamo/trtllm/args.py
+126
-0
components/src/dynamo/trtllm/backend_args.py
components/src/dynamo/trtllm/backend_args.py
+406
-0
components/src/dynamo/trtllm/main.py
components/src/dynamo/trtllm/main.py
+2
-2
components/src/dynamo/trtllm/tests/test_trtllm_unit.py
components/src/dynamo/trtllm/tests/test_trtllm_unit.py
+88
-4
components/src/dynamo/trtllm/utils/trtllm_utils.py
components/src/dynamo/trtllm/utils/trtllm_utils.py
+7
-598
components/src/dynamo/trtllm/workers/__init__.py
components/src/dynamo/trtllm/workers/__init__.py
+2
-2
components/src/dynamo/trtllm/workers/llm_worker.py
components/src/dynamo/trtllm/workers/llm_worker.py
+22
-25
components/src/dynamo/trtllm/workers/video_diffusion_worker.py
...nents/src/dynamo/trtllm/workers/video_diffusion_worker.py
+4
-4
No files found.
components/src/dynamo/trtllm/args.py
0 → 100644
View file @
5a67b246
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Argument parsing and typed config for Dynamo TRT-LLM."""
import
argparse
import
logging
import
os
import
sys
from
typing
import
Any
,
Dict
,
Optional
,
Sequence
from
dynamo.common.config_dump
import
register_encoder
from
dynamo.common.configuration.groups.runtime_args
import
(
DynamoRuntimeArgGroup
,
DynamoRuntimeConfig
,
)
from
dynamo.common.utils.runtime
import
parse_endpoint
from
dynamo.trtllm.backend_args
import
DynamoTrtllmArgGroup
,
DynamoTrtllmConfig
from
dynamo.trtllm.constants
import
DisaggregationMode
,
Modality
DEFAULT_ENDPOINT_COMPONENT
=
"tensorrt_llm"
DEFAULT_PREFILL_COMPONENT
=
"prefill"
DEFAULT_ENCODE_COMPONENT
=
"tensorrt_llm_encode"
DEFAULT_DIFFUSION_COMPONENT
=
"diffusion"
DEFAULT_ENDPOINT_NAME
=
"generate"
VALID_TRTLLM_CONNECTORS
=
{
"none"
,
"kvbm"
}
class
Config
(
DynamoRuntimeConfig
,
DynamoTrtllmConfig
):
component
:
str
use_kv_events
:
bool
def
validate
(
self
)
->
None
:
DynamoRuntimeConfig
.
validate
(
self
)
DynamoTrtllmConfig
.
validate
(
self
)
# Derive use_kv_events from publish_events_and_metrics
self
.
use_kv_events
=
self
.
publish_events_and_metrics
# fix the connector as trtllm accepts only one connector and it should be in VALID_TRTLLM_CONNECTORS
# while the runtime args accepts a list of connectors
if
self
.
connector
:
if
len
(
self
.
connector
)
>
1
:
raise
ValueError
(
"TRT-LLM supports at most one connector entry. Use `--connector none` or `--connector kvbm`."
)
elif
self
.
connector
[
0
]
not
in
VALID_TRTLLM_CONNECTORS
:
source
=
(
f
"DYN_CONNECTOR environment variable ('
{
os
.
environ
[
'DYN_CONNECTOR'
]
}
')"
if
"DYN_CONNECTOR"
in
os
.
environ
else
f
"shared runtime default ('
{
self
.
connector
[
0
]
}
')"
)
logging
.
warning
(
f
"TRT-LLM does not support connector '
{
self
.
connector
[
0
]
}
' (set via
{
source
}
). "
f
"Supported connectors:
{
VALID_TRTLLM_CONNECTORS
}
. Falling back to 'none'."
)
self
.
connector
=
[
"none"
]
def
has_connector
(
self
,
connector_name
:
str
)
->
bool
:
return
(
self
.
connector
is
not
None
and
len
(
self
.
connector
)
>
0
and
connector_name
==
self
.
connector
[
0
]
)
@
register_encoder
(
Config
)
def
_preprocess_for_encode_config
(
config
:
Config
)
->
Dict
[
str
,
Any
]:
return
config
.
__dict__
def
parse_args
(
argv
:
Optional
[
Sequence
[
str
]]
=
None
)
->
Config
:
"""Parse command-line arguments for the TensorRT-LLM backend."""
cli_args
=
list
(
argv
)
if
argv
is
not
None
else
sys
.
argv
[
1
:]
parser
=
argparse
.
ArgumentParser
(
description
=
"Dynamo TensorRT-LLM worker configuration"
,
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
DynamoRuntimeArgGroup
().
add_arguments
(
parser
)
DynamoTrtllmArgGroup
().
add_arguments
(
parser
)
parsed_args
=
parser
.
parse_args
(
cli_args
)
config
=
Config
.
from_cli_args
(
parsed_args
)
config
.
validate
()
# TODO: move this to common configuration.
if
config
.
custom_jinja_template
:
expanded_template_path
=
os
.
path
.
expanduser
(
os
.
path
.
expandvars
(
config
.
custom_jinja_template
)
)
if
not
os
.
path
.
isfile
(
expanded_template_path
):
raise
FileNotFoundError
(
f
"Custom Jinja template file not found:
{
expanded_template_path
}
"
)
config
.
custom_jinja_template
=
expanded_template_path
else
:
config
.
custom_jinja_template
=
None
endpoint
=
config
.
endpoint
or
_default_endpoint
(
namespace
=
config
.
namespace
,
modality
=
config
.
modality
,
disaggregation_mode
=
config
.
disaggregation_mode
,
)
parsed_namespace
,
parsed_component_name
,
parsed_endpoint_name
=
parse_endpoint
(
endpoint
)
config
.
namespace
=
parsed_namespace
config
.
component
=
parsed_component_name
config
.
endpoint
=
parsed_endpoint_name
return
config
def
_default_endpoint
(
namespace
:
str
,
modality
:
Modality
,
disaggregation_mode
:
DisaggregationMode
)
->
str
:
if
modality
==
Modality
.
VIDEO_DIFFUSION
:
component_name
=
DEFAULT_DIFFUSION_COMPONENT
elif
disaggregation_mode
==
DisaggregationMode
.
ENCODE
:
component_name
=
DEFAULT_ENCODE_COMPONENT
elif
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
component_name
=
DEFAULT_PREFILL_COMPONENT
else
:
component_name
=
DEFAULT_ENDPOINT_COMPONENT
return
f
"dyn://
{
namespace
}
.
{
component_name
}
.
{
DEFAULT_ENDPOINT_NAME
}
"
components/src/dynamo/trtllm/backend_args.py
0 → 100644
View file @
5a67b246
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Dynamo TRT-LLM backend configuration ArgGroup."""
from
typing
import
Optional
from
tensorrt_llm.llmapi
import
BuildConfig
from
dynamo.common.configuration.arg_group
import
ArgGroup
from
dynamo.common.configuration.config_base
import
ConfigBase
from
dynamo.common.configuration.utils
import
add_argument
,
add_negatable_bool_argument
from
.
import
__version__
from
.constants
import
DisaggregationMode
,
Modality
DEFAULT_MODEL
=
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
class
DynamoTrtllmArgGroup
(
ArgGroup
):
"""TensorRT-LLM-specific Dynamo wrapper configuration."""
def
add_arguments
(
self
,
parser
)
->
None
:
parser
.
add_argument
(
"--version"
,
action
=
"version"
,
version
=
f
"Dynamo Backend TRTLLM
{
__version__
}
"
,
)
g
=
parser
.
add_argument_group
(
"Dynamo TRT-LLM Options"
)
add_argument
(
g
,
flag_name
=
"--model"
,
env_var
=
"DYN_TRTLLM_MODEL"
,
default
=
DEFAULT_MODEL
,
obsolete_flag
=
"--model-path"
,
help
=
(
"Path to disk model or HuggingFace model identifier to load. "
),
)
add_argument
(
g
,
flag_name
=
"--served-model-name"
,
env_var
=
"DYN_TRTLLM_SERVED_MODEL_NAME"
,
default
=
None
,
help
=
"Name to serve the model under. Defaults to deriving it from model path."
,
)
add_argument
(
g
,
flag_name
=
"--tensor-parallel-size"
,
env_var
=
"DYN_TRTLLM_TENSOR_PARALLEL_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"Tensor parallelism size."
,
)
add_argument
(
g
,
flag_name
=
"--pipeline-parallel-size"
,
env_var
=
"DYN_TRTLLM_PIPELINE_PARALLEL_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"Pipeline parallelism size."
,
)
add_argument
(
g
,
flag_name
=
"--expert-parallel-size"
,
env_var
=
"DYN_TRTLLM_EXPERT_PARALLEL_SIZE"
,
default
=
None
,
arg_type
=
int
,
help
=
"Expert parallelism size."
,
)
add_negatable_bool_argument
(
g
,
flag_name
=
"--enable-attention-dp"
,
env_var
=
"DYN_TRTLLM_ENABLE_ATTENTION_DP"
,
default
=
False
,
help
=
"Enable attention data parallelism. When enabled, attention_dp_size equals tensor_parallel_size."
,
)
add_argument
(
g
,
flag_name
=
"--kv-block-size"
,
env_var
=
"DYN_TRTLLM_KV_BLOCK_SIZE"
,
default
=
32
,
arg_type
=
int
,
help
=
"Size of a KV cache block."
,
)
add_argument
(
g
,
flag_name
=
"--gpus-per-node"
,
env_var
=
"DYN_TRTLLM_GPUS_PER_NODE"
,
default
=
None
,
arg_type
=
int
,
help
=
"Number of GPUs per node. If not provided, inferred from the environment."
,
)
add_argument
(
g
,
flag_name
=
"--max-batch-size"
,
env_var
=
"DYN_TRTLLM_MAX_BATCH_SIZE"
,
default
=
BuildConfig
.
model_fields
[
"max_batch_size"
].
default
,
arg_type
=
int
,
help
=
"Maximum number of requests that the engine can schedule."
,
)
add_argument
(
g
,
flag_name
=
"--max-num-tokens"
,
env_var
=
"DYN_TRTLLM_MAX_NUM_TOKENS"
,
default
=
BuildConfig
.
model_fields
[
"max_num_tokens"
].
default
,
arg_type
=
int
,
help
=
"Maximum number of batched input tokens after padding is removed in each batch."
,
)
add_argument
(
g
,
flag_name
=
"--max-seq-len"
,
env_var
=
"DYN_TRTLLM_MAX_SEQ_LEN"
,
default
=
BuildConfig
.
model_fields
[
"max_seq_len"
].
default
,
arg_type
=
int
,
help
=
"Maximum total length of one request, including prompt and outputs. If unspecified, the value is deduced from the model config."
,
)
add_argument
(
g
,
flag_name
=
"--max-beam-width"
,
env_var
=
"DYN_TRTLLM_MAX_BEAM_WIDTH"
,
default
=
BuildConfig
.
model_fields
[
"max_beam_width"
].
default
,
arg_type
=
int
,
help
=
"Maximum number of beams for beam search decoding."
,
)
add_argument
(
g
,
flag_name
=
"--free-gpu-memory-fraction"
,
env_var
=
"DYN_TRTLLM_FREE_GPU_MEMORY_FRACTION"
,
default
=
0.9
,
arg_type
=
float
,
help
=
"Free GPU memory fraction reserved for KV Cache, after model weights and buffers are allocated."
,
)
add_argument
(
g
,
flag_name
=
"--extra-engine-args"
,
env_var
=
"DYN_TRTLLM_EXTRA_ENGINE_ARGS"
,
default
=
""
,
help
=
"Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine."
,
)
add_argument
(
g
,
flag_name
=
"--override-engine-args"
,
env_var
=
"DYN_TRTLLM_OVERRIDE_ENGINE_ARGS"
,
default
=
""
,
help
=
"Python dictionary string to override specific engine arguments from the YAML file. "
'Example:
\'
{"tensor_parallel_size": 2, "kv_cache_config": {"enable_block_reuse": false}}
\'
'
,
)
add_negatable_bool_argument
(
g
,
flag_name
=
"--publish-events-and-metrics"
,
env_var
=
"DYN_TRTLLM_PUBLISH_EVENTS_AND_METRICS"
,
default
=
False
,
help
=
"If set, publish events and metrics to Dynamo components."
,
)
add_argument
(
g
,
flag_name
=
"--disaggregation-mode"
,
env_var
=
"DYN_TRTLLM_DISAGGREGATION_MODE"
,
default
=
DisaggregationMode
.
AGGREGATED
.
value
,
choices
=
[
mode
.
value
for
mode
in
DisaggregationMode
],
help
=
"Mode to use for disaggregation."
,
)
add_argument
(
g
,
flag_name
=
"--modality"
,
env_var
=
"DYN_TRTLLM_MODALITY"
,
default
=
Modality
.
TEXT
.
value
,
choices
=
[
m
.
value
for
m
in
Modality
],
help
=
"Modality to use for the model."
,
)
add_argument
(
g
,
flag_name
=
"--encode-endpoint"
,
env_var
=
"DYN_TRTLLM_ENCODE_ENDPOINT"
,
default
=
""
,
help
=
"Endpoint (in 'dyn://namespace.component.endpoint' format) for the encode worker."
,
)
add_argument
(
g
,
flag_name
=
"--allowed-local-media-path"
,
env_var
=
"DYN_TRTLLM_ALLOWED_LOCAL_MEDIA_PATH"
,
default
=
""
,
help
=
"Path to a directory that is allowed to be accessed by the model."
,
)
add_argument
(
g
,
flag_name
=
"--max-file-size-mb"
,
env_var
=
"DYN_TRTLLM_MAX_FILE_SIZE_MB"
,
default
=
50
,
arg_type
=
int
,
help
=
"Maximum size of downloadable embedding files/Image URLs."
,
)
diffusion_group
=
parser
.
add_argument_group
(
"Diffusion Options [Experimental]"
,
"Options for video_diffusion modality"
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--output-dir"
,
env_var
=
"DYN_TRTLLM_OUTPUT_DIR"
,
default
=
"/tmp/dynamo_videos"
,
help
=
"Directory to store generated videos/images."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--default-height"
,
env_var
=
"DYN_TRTLLM_DEFAULT_HEIGHT"
,
default
=
480
,
arg_type
=
int
,
help
=
"Default video/image height in pixels."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--default-width"
,
env_var
=
"DYN_TRTLLM_DEFAULT_WIDTH"
,
default
=
832
,
arg_type
=
int
,
help
=
"Default video/image width in pixels."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--default-num-frames"
,
env_var
=
"DYN_TRTLLM_DEFAULT_NUM_FRAMES"
,
default
=
81
,
arg_type
=
int
,
help
=
"Default number of frames for video generation."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--default-num-inference-steps"
,
env_var
=
"DYN_TRTLLM_DEFAULT_NUM_INFERENCE_STEPS"
,
default
=
50
,
arg_type
=
int
,
help
=
"Default number of inference steps."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--default-guidance-scale"
,
env_var
=
"DYN_TRTLLM_DEFAULT_GUIDANCE_SCALE"
,
default
=
5.0
,
arg_type
=
float
,
help
=
"Default CFG guidance scale."
,
)
add_negatable_bool_argument
(
diffusion_group
,
flag_name
=
"--enable-teacache"
,
env_var
=
"DYN_TRTLLM_ENABLE_TEACACHE"
,
default
=
False
,
help
=
"Enable TeaCache optimization for faster generation."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--teacache-thresh"
,
env_var
=
"DYN_TRTLLM_TEACACHE_THRESH"
,
default
=
0.2
,
arg_type
=
float
,
help
=
"TeaCache threshold."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--attn-type"
,
env_var
=
"DYN_TRTLLM_ATTN_TYPE"
,
default
=
"default"
,
choices
=
[
"default"
,
"sage-attn"
,
"sparse-videogen"
,
"sparse-videogen2"
],
help
=
"Attention type for diffusion models."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--linear-type"
,
env_var
=
"DYN_TRTLLM_LINEAR_TYPE"
,
default
=
"default"
,
choices
=
[
"default"
,
"trtllm-fp8-blockwise"
,
"trtllm-fp8-per-tensor"
,
"trtllm-nvfp4"
,
],
help
=
"Linear type for quantization."
,
)
add_negatable_bool_argument
(
diffusion_group
,
flag_name
=
"--disable-torch-compile"
,
env_var
=
"DYN_TRTLLM_DISABLE_TORCH_COMPILE"
,
default
=
False
,
help
=
"Disable torch.compile optimization."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--torch-compile-mode"
,
env_var
=
"DYN_TRTLLM_TORCH_COMPILE_MODE"
,
default
=
"default"
,
choices
=
[
"default"
,
"reduce-overhead"
,
"max-autotune"
],
help
=
"torch.compile mode."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--dit-dp-size"
,
env_var
=
"DYN_TRTLLM_DIT_DP_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"Data parallel size for DiT."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--dit-tp-size"
,
env_var
=
"DYN_TRTLLM_DIT_TP_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"Tensor parallel size for DiT."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--dit-ulysses-size"
,
env_var
=
"DYN_TRTLLM_DIT_ULYSSES_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"Ulysses parallel size for DiT."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--dit-ring-size"
,
env_var
=
"DYN_TRTLLM_DIT_RING_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"Ring parallel size for DiT."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--dit-cfg-size"
,
env_var
=
"DYN_TRTLLM_DIT_CFG_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"CFG parallel size for DiT."
,
)
add_argument
(
diffusion_group
,
flag_name
=
"--dit-fsdp-size"
,
env_var
=
"DYN_TRTLLM_DIT_FSDP_SIZE"
,
default
=
1
,
arg_type
=
int
,
help
=
"FSDP size for DiT."
,
)
add_negatable_bool_argument
(
diffusion_group
,
flag_name
=
"--enable-async-cpu-offload"
,
env_var
=
"DYN_TRTLLM_ENABLE_ASYNC_CPU_OFFLOAD"
,
default
=
False
,
help
=
"Enable async CPU offload for memory efficiency."
,
)
class
DynamoTrtllmConfig
(
ConfigBase
):
"""Configuration for Dynamo TRT-LLM backend-specific options."""
model
:
str
served_model_name
:
Optional
[
str
]
=
None
tensor_parallel_size
:
int
pipeline_parallel_size
:
int
expert_parallel_size
:
Optional
[
int
]
enable_attention_dp
:
bool
kv_block_size
:
int
gpus_per_node
:
Optional
[
int
]
=
None
max_batch_size
:
int
max_num_tokens
:
int
max_seq_len
:
int
max_beam_width
:
int
free_gpu_memory_fraction
:
float
extra_engine_args
:
str
override_engine_args
:
str
publish_events_and_metrics
:
bool
disaggregation_mode
:
DisaggregationMode
modality
:
Modality
encode_endpoint
:
str
allowed_local_media_path
:
str
max_file_size_mb
:
int
output_dir
:
str
default_height
:
int
default_width
:
int
default_num_frames
:
int
default_num_inference_steps
:
int
default_guidance_scale
:
float
enable_teacache
:
bool
teacache_thresh
:
float
attn_type
:
str
linear_type
:
str
disable_torch_compile
:
bool
torch_compile_mode
:
str
dit_dp_size
:
int
dit_tp_size
:
int
dit_ulysses_size
:
int
dit_ring_size
:
int
dit_cfg_size
:
int
dit_fsdp_size
:
int
enable_async_cpu_offload
:
bool
def
validate
(
self
)
->
None
:
if
isinstance
(
self
.
disaggregation_mode
,
str
):
self
.
disaggregation_mode
=
DisaggregationMode
(
self
.
disaggregation_mode
)
if
isinstance
(
self
.
modality
,
str
):
self
.
modality
=
Modality
(
self
.
modality
)
if
not
self
.
served_model_name
:
self
.
served_model_name
=
None
components/src/dynamo/trtllm/main.py
View file @
5a67b246
...
@@ -20,14 +20,14 @@ import uvloop
...
@@ -20,14 +20,14 @@ import uvloop
from
dynamo.common.utils.runtime
import
create_runtime
from
dynamo.common.utils.runtime
import
create_runtime
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.trtllm.
utils.trtllm_util
s
import
cmd_lin
e_args
from
dynamo.trtllm.
arg
s
import
pars
e_args
from
dynamo.trtllm.workers
import
init_worker
from
dynamo.trtllm.workers
import
init_worker
configure_dynamo_logging
()
configure_dynamo_logging
()
async
def
worker
():
async
def
worker
():
config
=
cmd_lin
e_args
()
config
=
pars
e_args
()
shutdown_event
=
asyncio
.
Event
()
shutdown_event
=
asyncio
.
Event
()
runtime
,
_
=
create_runtime
(
runtime
,
_
=
create_runtime
(
...
...
components/src/dynamo/trtllm/tests/test_trtllm_unit.py
View file @
5a67b246
...
@@ -16,8 +16,9 @@ if not torch.cuda.is_available():
...
@@ -16,8 +16,9 @@ if not torch.cuda.is_available():
allow_module_level
=
True
,
allow_module_level
=
True
,
)
)
from
dynamo.trtllm.args
import
Config
,
parse_args
from
dynamo.trtllm.tests.conftest
import
make_cli_args_fixture
from
dynamo.trtllm.tests.conftest
import
make_cli_args_fixture
from
dynamo.trtllm.utils.trtllm_utils
import
cmd_line_args
from
dynamo.trtllm.utils.trtllm_utils
import
deep_update
# Get path relative to this test file
# Get path relative to this test file
REPO_ROOT
=
Path
(
__file__
).
resolve
().
parents
[
5
]
REPO_ROOT
=
Path
(
__file__
).
resolve
().
parents
[
5
]
...
@@ -51,13 +52,13 @@ def test_custom_jinja_template_invalid_path(mock_trtllm_cli):
...
@@ -51,13 +52,13 @@ def test_custom_jinja_template_invalid_path(mock_trtllm_cli):
FileNotFoundError
,
FileNotFoundError
,
match
=
re
.
escape
(
f
"Custom Jinja template file not found:
{
invalid_path
}
"
),
match
=
re
.
escape
(
f
"Custom Jinja template file not found:
{
invalid_path
}
"
),
):
):
cmd_lin
e_args
()
#
This will read in from argv
pars
e_args
()
#
Reads from argv set by fixture
def
test_custom_jinja_template_valid_path
(
mock_trtllm_cli
):
def
test_custom_jinja_template_valid_path
(
mock_trtllm_cli
):
"""Test that valid absolute path is stored correctly."""
"""Test that valid absolute path is stored correctly."""
mock_trtllm_cli
(
model
=
"Qwen/Qwen3-0.6B"
,
custom_jinja_template
=
JINJA_TEMPLATE_PATH
)
mock_trtllm_cli
(
model
=
"Qwen/Qwen3-0.6B"
,
custom_jinja_template
=
JINJA_TEMPLATE_PATH
)
config
=
cmd_lin
e_args
()
config
=
pars
e_args
()
assert
config
.
custom_jinja_template
==
JINJA_TEMPLATE_PATH
,
(
assert
config
.
custom_jinja_template
==
JINJA_TEMPLATE_PATH
,
(
f
"Expected custom_jinja_template value to be
{
JINJA_TEMPLATE_PATH
}
, "
f
"Expected custom_jinja_template value to be
{
JINJA_TEMPLATE_PATH
}
, "
...
@@ -73,10 +74,93 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_trtllm_cli):
...
@@ -73,10 +74,93 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_trtllm_cli):
cli_path
=
"$JINJA_DIR/custom_template.jinja"
cli_path
=
"$JINJA_DIR/custom_template.jinja"
mock_trtllm_cli
(
model
=
"Qwen/Qwen3-0.6B"
,
custom_jinja_template
=
cli_path
)
mock_trtllm_cli
(
model
=
"Qwen/Qwen3-0.6B"
,
custom_jinja_template
=
cli_path
)
config
=
cmd_lin
e_args
()
config
=
pars
e_args
()
assert
"$JINJA_DIR"
not
in
config
.
custom_jinja_template
assert
"$JINJA_DIR"
not
in
config
.
custom_jinja_template
assert
config
.
custom_jinja_template
==
JINJA_TEMPLATE_PATH
,
(
assert
config
.
custom_jinja_template
==
JINJA_TEMPLATE_PATH
,
(
f
"Expected custom_jinja_template value to be
{
JINJA_TEMPLATE_PATH
}
, "
f
"Expected custom_jinja_template value to be
{
JINJA_TEMPLATE_PATH
}
, "
f
"got
{
config
.
custom_jinja_template
}
"
f
"got
{
config
.
custom_jinja_template
}
"
)
)
# ---- Tests for trtllm/args.py (Config, parse_args) ----
def
test_parse_args_returns_config_with_expected_attrs
(
monkeypatch
):
"""parse_args returns a Config instance with model, component, and endpoint set."""
monkeypatch
.
delenv
(
"DYN_NAMESPACE"
,
raising
=
False
)
monkeypatch
.
delenv
(
"DYN_TRTLLM_MODEL"
,
raising
=
False
)
config
=
parse_args
([
"--namespace"
,
"testns"
,
"--model-path"
,
"Qwen/Qwen3-0.6B"
])
assert
isinstance
(
config
,
Config
)
assert
config
.
model
==
"Qwen/Qwen3-0.6B"
assert
config
.
namespace
==
"testns"
assert
config
.
component
==
"tensorrt_llm"
assert
config
.
endpoint
==
"generate"
def
test_config_use_kv_events_derived_from_publish_events
(
monkeypatch
):
"""Config.validate sets use_kv_events from publish_events_and_metrics."""
monkeypatch
.
delenv
(
"DYN_TRTLLM_PUBLISH_EVENTS"
,
raising
=
False
)
config
=
parse_args
([
"--publish-events"
])
assert
config
.
publish_events_and_metrics
is
True
assert
config
.
use_kv_events
is
True
config_off
=
parse_args
([
"--no-publish-events"
])
assert
config_off
.
publish_events_and_metrics
is
False
assert
config_off
.
use_kv_events
is
False
def
test_config_has_connector
(
monkeypatch
):
"""Config.has_connector returns True only for the single configured connector."""
monkeypatch
.
delenv
(
"DYN_CONNECTOR"
,
raising
=
False
)
config_none
=
parse_args
([
"--connector"
,
"none"
])
assert
config_none
.
has_connector
(
"none"
)
is
True
assert
config_none
.
has_connector
(
"kvbm"
)
is
False
config_kvbm
=
parse_args
([
"--connector"
,
"kvbm"
])
assert
config_kvbm
.
has_connector
(
"kvbm"
)
is
True
assert
config_kvbm
.
has_connector
(
"none"
)
is
False
def
test_config_multiple_connectors_fails
(
monkeypatch
):
"""Config.validate fails if multiple connectors are provided."""
monkeypatch
.
delenv
(
"DYN_CONNECTOR"
,
raising
=
False
)
with
pytest
.
raises
(
ValueError
,
match
=
"TRT-LLM supports at most one connector entry. Use `--connector none` or `--connector kvbm`."
,
):
parse_args
([
"--connector"
,
"none"
,
"kvbm"
])
# ---- Tests for trtllm_utils.deep_update ----
def
test_deep_update_nested_merge
():
"""deep_update merges nested dicts without removing existing keys."""
target
=
{
"a"
:
1
,
"b"
:
{
"x"
:
10
,
"y"
:
20
}}
source
=
{
"b"
:
{
"y"
:
21
,
"z"
:
30
}}
deep_update
(
target
,
source
)
assert
target
==
{
"a"
:
1
,
"b"
:
{
"x"
:
10
,
"y"
:
21
,
"z"
:
30
}}
def
test_deep_update_overwrites_scalar_with_value
():
"""deep_update overwrites a key with a non-dict value."""
target
=
{
"a"
:
1
,
"b"
:
{
"x"
:
10
}}
source
=
{
"a"
:
2
,
"b"
:
99
}
deep_update
(
target
,
source
)
assert
target
==
{
"a"
:
2
,
"b"
:
99
}
def
test_deep_update_empty_source_unchanged
():
"""deep_update with empty source leaves target unchanged."""
target
=
{
"a"
:
1
,
"b"
:
{
"x"
:
10
}}
deep_update
(
target
,
{})
assert
target
==
{
"a"
:
1
,
"b"
:
{
"x"
:
10
}}
def
test_deep_update_adds_new_keys
():
"""deep_update adds new keys from source that are not in target."""
target
=
{
"a"
:
1
}
source
=
{
"b"
:
2
,
"c"
:
{
"nested"
:
3
}}
deep_update
(
target
,
source
)
assert
target
==
{
"a"
:
1
,
"b"
:
2
,
"c"
:
{
"nested"
:
3
}}
components/src/dynamo/trtllm/utils/trtllm_utils.py
View file @
5a67b246
This diff is collapsed.
Click to expand it.
components/src/dynamo/trtllm/workers/__init__.py
View file @
5a67b246
...
@@ -20,8 +20,8 @@ import asyncio
...
@@ -20,8 +20,8 @@ import asyncio
import
logging
import
logging
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.trtllm.args
import
Config
from
dynamo.trtllm.constants
import
Modality
from
dynamo.trtllm.constants
import
Modality
from
dynamo.trtllm.utils.trtllm_utils
import
Config
from
dynamo.trtllm.workers.llm_worker
import
init_llm_worker
from
dynamo.trtllm.workers.llm_worker
import
init_llm_worker
...
@@ -40,7 +40,7 @@ async def init_worker(
...
@@ -40,7 +40,7 @@ async def init_worker(
"""
"""
logging
.
info
(
f
"Initializing worker with modality=
{
config
.
modality
}
"
)
logging
.
info
(
f
"Initializing worker with modality=
{
config
.
modality
}
"
)
modality
=
Modality
(
config
.
modality
)
modality
=
config
.
modality
if
Modality
.
is_diffusion
(
modality
):
if
Modality
.
is_diffusion
(
modality
):
if
modality
==
Modality
.
VIDEO_DIFFUSION
:
if
modality
==
Modality
.
VIDEO_DIFFUSION
:
...
...
components/src/dynamo/trtllm/workers/llm_worker.py
View file @
5a67b246
...
@@ -45,6 +45,7 @@ from dynamo.llm import (
...
@@ -45,6 +45,7 @@ from dynamo.llm import (
register_model
,
register_model
,
)
)
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.trtllm.args
import
Config
from
dynamo.trtllm.constants
import
DisaggregationMode
from
dynamo.trtllm.constants
import
DisaggregationMode
from
dynamo.trtllm.engine
import
Backend
,
TensorRTLLMEngine
,
get_llm_engine
from
dynamo.trtllm.engine
import
Backend
,
TensorRTLLMEngine
,
get_llm_engine
from
dynamo.trtllm.health_check
import
TrtllmHealthCheckPayload
from
dynamo.trtllm.health_check
import
TrtllmHealthCheckPayload
...
@@ -54,7 +55,7 @@ from dynamo.trtllm.request_handlers.handlers import (
...
@@ -54,7 +55,7 @@ from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig
,
RequestHandlerConfig
,
RequestHandlerFactory
,
RequestHandlerFactory
,
)
)
from
dynamo.trtllm.utils.trtllm_utils
import
Config
,
deep_update
from
dynamo.trtllm.utils.trtllm_utils
import
deep_update
# Default buffer size for kv cache events.
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
=
1024
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
=
1024
...
@@ -92,17 +93,17 @@ async def get_engine_runtime_config(
...
@@ -92,17 +93,17 @@ async def get_engine_runtime_config(
def
build_kv_connector_config
(
config
:
Config
):
def
build_kv_connector_config
(
config
:
Config
):
if
config
.
connector
is
not
None
:
if
config
.
connector
:
if
config
.
connector
==
"kvbm"
:
if
config
.
connector
[
0
]
==
"kvbm"
:
return
KvCacheConnectorConfig
(
return
KvCacheConnectorConfig
(
connector_module
=
"kvbm.trtllm_integration.connector"
,
connector_module
=
"kvbm.trtllm_integration.connector"
,
connector_scheduler_class
=
"DynamoKVBMConnectorLeader"
,
connector_scheduler_class
=
"DynamoKVBMConnectorLeader"
,
connector_worker_class
=
"DynamoKVBMConnectorWorker"
,
connector_worker_class
=
"DynamoKVBMConnectorWorker"
,
)
)
elif
config
.
connector
==
"none"
:
elif
config
.
connector
[
0
]
==
"none"
:
return
None
return
None
else
:
else
:
logging
.
error
(
f
"Invalid connector:
{
config
.
connector
}
"
)
logging
.
error
(
f
"Invalid connector:
{
config
.
connector
[
0
]
}
"
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
return
None
return
None
...
@@ -138,7 +139,7 @@ async def init_llm_worker(
...
@@ -138,7 +139,7 @@ async def init_llm_worker(
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
# Convert model path to Path object if it's a local path, otherwise keep as string
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path
=
str
(
config
.
model
_path
)
model_path
=
str
(
config
.
model
)
if
config
.
gpus_per_node
is
None
:
if
config
.
gpus_per_node
is
None
:
gpus_per_node
=
device_count
()
gpus_per_node
=
device_count
()
...
@@ -151,7 +152,7 @@ async def init_llm_worker(
...
@@ -151,7 +152,7 @@ async def init_llm_worker(
free_gpu_memory_fraction
=
config
.
free_gpu_memory_fraction
free_gpu_memory_fraction
=
config
.
free_gpu_memory_fraction
)
)
if
config
.
connector
is
not
None
and
"kvbm"
in
config
.
connector
:
if
config
.
has_
connector
(
"kvbm"
)
:
kv_cache_config
.
enable_partial_reuse
=
False
kv_cache_config
.
enable_partial_reuse
=
False
dynamic_batch_config
=
DynamicBatchConfig
(
dynamic_batch_config
=
DynamicBatchConfig
(
...
@@ -275,15 +276,13 @@ async def init_llm_worker(
...
@@ -275,15 +276,13 @@ async def init_llm_worker(
if
config
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
config
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
model_type
=
ModelType
.
Prefill
model_type
=
ModelType
.
Prefill
else
:
else
:
model_type
=
parse_endpoint_types
(
config
.
dyn_endpoint_types
)
model_type
=
parse_endpoint_types
(
config
.
endpoint_types
)
logging
.
info
(
logging
.
info
(
f
"Registering model with endpoint types:
{
config
.
endpoint_types
}
"
)
f
"Registering model with endpoint types:
{
config
.
dyn_endpoint_types
}
"
)
# Warn if custom template provided but chat endpoint not enabled
# Warn if custom template provided but chat endpoint not enabled
if
config
.
custom_jinja_template
and
"chat"
not
in
config
.
dyn_
endpoint_types
:
if
config
.
custom_jinja_template
and
"chat"
not
in
config
.
endpoint_types
:
logging
.
warning
(
logging
.
warning
(
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --
dyn-
endpoint-types. "
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --endpoint-types. "
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
)
)
...
@@ -298,12 +297,10 @@ async def init_llm_worker(
...
@@ -298,12 +297,10 @@ async def init_llm_worker(
if
modality
==
"multimodal"
:
if
modality
==
"multimodal"
:
engine_args
[
"skip_tokenizer_init"
]
=
False
engine_args
[
"skip_tokenizer_init"
]
=
False
model_config
=
AutoConfig
.
from_pretrained
(
model_config
=
AutoConfig
.
from_pretrained
(
config
.
model
,
trust_remote_code
=
True
)
config
.
model_path
,
trust_remote_code
=
True
)
multimodal_processor
=
MultimodalRequestProcessor
(
multimodal_processor
=
MultimodalRequestProcessor
(
model_type
=
model_config
.
model_type
,
model_type
=
model_config
.
model_type
,
model_dir
=
config
.
model
_path
,
model_dir
=
config
.
model
,
max_file_size_mb
=
config
.
max_file_size_mb
,
max_file_size_mb
=
config
.
max_file_size_mb
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
allowed_local_media_path
=
config
.
allowed_local_media_path
,
allowed_local_media_path
=
config
.
allowed_local_media_path
,
...
@@ -322,7 +319,7 @@ async def init_llm_worker(
...
@@ -322,7 +319,7 @@ async def init_llm_worker(
)
)
# Prepare model name for metrics
# Prepare model name for metrics
model_name_for_metrics
=
config
.
served_model_name
or
config
.
model
_path
model_name_for_metrics
=
config
.
served_model_name
or
config
.
model
# Construct Prometheus gauges directly; passed through to the engine and publisher
# Construct Prometheus gauges directly; passed through to the engine and publisher
# via explicit parameters (no module-level global).
# via explicit parameters (no module-level global).
...
@@ -357,8 +354,8 @@ async def init_llm_worker(
...
@@ -357,8 +354,8 @@ async def init_llm_worker(
# Both parameters control the same thing: how many requests can be processed simultaneously
# Both parameters control the same thing: how many requests can be processed simultaneously
runtime_config
.
max_num_seqs
=
config
.
max_batch_size
runtime_config
.
max_num_seqs
=
config
.
max_batch_size
runtime_config
.
max_num_batched_tokens
=
config
.
max_num_tokens
runtime_config
.
max_num_batched_tokens
=
config
.
max_num_tokens
runtime_config
.
reasoning_parser
=
config
.
reasoning_parser
runtime_config
.
reasoning_parser
=
config
.
dyn_
reasoning_parser
runtime_config
.
tool_call_parser
=
config
.
tool_call_parser
runtime_config
.
tool_call_parser
=
config
.
dyn_
tool_call_parser
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
runtime_config
.
enable_local_indexer
=
(
runtime_config
.
enable_local_indexer
=
(
config
.
enable_local_indexer
config
.
enable_local_indexer
...
@@ -386,7 +383,7 @@ async def init_llm_worker(
...
@@ -386,7 +383,7 @@ async def init_llm_worker(
metrics_collector
=
None
metrics_collector
=
None
if
config
.
publish_events_and_metrics
:
if
config
.
publish_events_and_metrics
:
try
:
try
:
model_name_for_metrics
=
config
.
served_model_name
or
config
.
model
_path
model_name_for_metrics
=
config
.
served_model_name
or
config
.
model
metrics_collector
=
MetricsCollector
(
metrics_collector
=
MetricsCollector
(
{
"model_name"
:
model_name_for_metrics
,
"engine_type"
:
"trtllm"
}
{
"model_name"
:
model_name_for_metrics
,
"engine_type"
:
"trtllm"
}
)
)
...
@@ -430,7 +427,7 @@ async def init_llm_worker(
...
@@ -430,7 +427,7 @@ async def init_llm_worker(
metrics_collector
=
metrics_collector
,
metrics_collector
=
metrics_collector
,
kv_block_size
=
config
.
kv_block_size
,
kv_block_size
=
config
.
kv_block_size
,
shutdown_event
=
shutdown_event
,
shutdown_event
=
shutdown_event
,
encoder_cache_capacity_gb
=
config
.
encoder
_cache_capacity_gb
,
encoder_cache_capacity_gb
=
config
.
multimodal_embedding
_cache_capacity_gb
,
)
)
# Register the model with runtime config
# Register the model with runtime config
...
@@ -441,7 +438,7 @@ async def init_llm_worker(
...
@@ -441,7 +438,7 @@ async def init_llm_worker(
model_input
,
model_input
,
model_type
,
model_type
,
endpoint
,
endpoint
,
config
.
model
_path
,
config
.
model
,
config
.
served_model_name
,
config
.
served_model_name
,
kv_cache_block_size
=
config
.
kv_block_size
,
kv_cache_block_size
=
config
.
kv_block_size
,
runtime_config
=
runtime_config
,
runtime_config
=
runtime_config
,
...
@@ -457,8 +454,8 @@ async def init_llm_worker(
...
@@ -457,8 +454,8 @@ async def init_llm_worker(
kv_listener
=
runtime
.
namespace
(
config
.
namespace
).
component
(
kv_listener
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
config
.
component
)
)
# Use model
_path
as fallback if served_model_name is not provided
# Use model as fallback if served_model_name is not provided
model_name_for_metrics
=
config
.
served_model_name
or
config
.
model
_path
model_name_for_metrics
=
config
.
served_model_name
or
config
.
model
metrics_labels
=
[
metrics_labels
=
[
(
(
prometheus_names
.
labels
.
MODEL
,
prometheus_names
.
labels
.
MODEL
,
...
...
components/src/dynamo/trtllm/workers/video_diffusion_worker.py
View file @
5a67b246
...
@@ -12,7 +12,7 @@ import logging
...
@@ -12,7 +12,7 @@ import logging
from
dynamo.llm
import
ModelInput
,
ModelType
,
register_model
from
dynamo.llm
import
ModelInput
,
ModelType
,
register_model
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.trtllm.
utils.trtllm_util
s
import
Config
from
dynamo.trtllm.
arg
s
import
Config
async
def
init_video_diffusion_worker
(
async
def
init_video_diffusion_worker
(
...
@@ -58,7 +58,7 @@ async def init_video_diffusion_worker(
...
@@ -58,7 +58,7 @@ async def init_video_diffusion_worker(
discovery_backend
=
config
.
discovery_backend
,
discovery_backend
=
config
.
discovery_backend
,
request_plane
=
config
.
request_plane
,
request_plane
=
config
.
request_plane
,
event_plane
=
config
.
event_plane
,
event_plane
=
config
.
event_plane
,
model_path
=
config
.
model
_path
,
model_path
=
config
.
model
,
served_model_name
=
config
.
served_model_name
,
served_model_name
=
config
.
served_model_name
,
output_dir
=
config
.
output_dir
,
output_dir
=
config
.
output_dir
,
default_height
=
config
.
default_height
,
default_height
=
config
.
default_height
,
...
@@ -93,7 +93,7 @@ async def init_video_diffusion_worker(
...
@@ -93,7 +93,7 @@ async def init_video_diffusion_worker(
handler
=
VideoGenerationHandler
(
component
,
engine
,
diffusion_config
)
handler
=
VideoGenerationHandler
(
component
,
engine
,
diffusion_config
)
# Register the model with Dynamo's discovery system
# Register the model with Dynamo's discovery system
model_name
=
config
.
served_model_name
or
config
.
model
_path
model_name
=
config
.
served_model_name
or
config
.
model
# Use ModelType.Videos for video generation
# Use ModelType.Videos for video generation
if
not
hasattr
(
ModelType
,
"Videos"
):
if
not
hasattr
(
ModelType
,
"Videos"
):
...
@@ -111,7 +111,7 @@ async def init_video_diffusion_worker(
...
@@ -111,7 +111,7 @@ async def init_video_diffusion_worker(
ModelInput
.
Text
,
ModelInput
.
Text
,
model_type
,
model_type
,
endpoint
,
endpoint
,
config
.
model
_path
,
config
.
model
,
model_name
,
model_name
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment