Unverified Commit 5a67b246 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: Migrate trtllm configuration (#6297)

parent 9a93eb75
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Argument parsing and typed config for Dynamo TRT-LLM."""
import argparse
import logging
import os
import sys
from typing import Any, Dict, Optional, Sequence
from dynamo.common.config_dump import register_encoder
from dynamo.common.configuration.groups.runtime_args import (
DynamoRuntimeArgGroup,
DynamoRuntimeConfig,
)
from dynamo.common.utils.runtime import parse_endpoint
from dynamo.trtllm.backend_args import DynamoTrtllmArgGroup, DynamoTrtllmConfig
from dynamo.trtllm.constants import DisaggregationMode, Modality
DEFAULT_ENDPOINT_COMPONENT = "tensorrt_llm"
DEFAULT_PREFILL_COMPONENT = "prefill"
DEFAULT_ENCODE_COMPONENT = "tensorrt_llm_encode"
DEFAULT_DIFFUSION_COMPONENT = "diffusion"
DEFAULT_ENDPOINT_NAME = "generate"
VALID_TRTLLM_CONNECTORS = {"none", "kvbm"}
class Config(DynamoRuntimeConfig, DynamoTrtllmConfig):
component: str
use_kv_events: bool
def validate(self) -> None:
DynamoRuntimeConfig.validate(self)
DynamoTrtllmConfig.validate(self)
# Derive use_kv_events from publish_events_and_metrics
self.use_kv_events = self.publish_events_and_metrics
# fix the connector as trtllm accepts only one connector and it should be in VALID_TRTLLM_CONNECTORS
# while the runtime args accepts a list of connectors
if self.connector:
if len(self.connector) > 1:
raise ValueError(
"TRT-LLM supports at most one connector entry. Use `--connector none` or `--connector kvbm`."
)
elif self.connector[0] not in VALID_TRTLLM_CONNECTORS:
source = (
f"DYN_CONNECTOR environment variable ('{os.environ['DYN_CONNECTOR']}')"
if "DYN_CONNECTOR" in os.environ
else f"shared runtime default ('{self.connector[0]}')"
)
logging.warning(
f"TRT-LLM does not support connector '{self.connector[0]}' (set via {source}). "
f"Supported connectors: {VALID_TRTLLM_CONNECTORS}. Falling back to 'none'."
)
self.connector = ["none"]
def has_connector(self, connector_name: str) -> bool:
return (
self.connector is not None
and len(self.connector) > 0
and connector_name == self.connector[0]
)
@register_encoder(Config)
def _preprocess_for_encode_config(config: Config) -> Dict[str, Any]:
return config.__dict__
def parse_args(argv: Optional[Sequence[str]] = None) -> Config:
"""Parse command-line arguments for the TensorRT-LLM backend."""
cli_args = list(argv) if argv is not None else sys.argv[1:]
parser = argparse.ArgumentParser(
description="Dynamo TensorRT-LLM worker configuration",
formatter_class=argparse.RawTextHelpFormatter,
)
DynamoRuntimeArgGroup().add_arguments(parser)
DynamoTrtllmArgGroup().add_arguments(parser)
parsed_args = parser.parse_args(cli_args)
config = Config.from_cli_args(parsed_args)
config.validate()
# TODO: move this to common configuration.
if config.custom_jinja_template:
expanded_template_path = os.path.expanduser(
os.path.expandvars(config.custom_jinja_template)
)
if not os.path.isfile(expanded_template_path):
raise FileNotFoundError(
f"Custom Jinja template file not found: {expanded_template_path}"
)
config.custom_jinja_template = expanded_template_path
else:
config.custom_jinja_template = None
endpoint = config.endpoint or _default_endpoint(
namespace=config.namespace,
modality=config.modality,
disaggregation_mode=config.disaggregation_mode,
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
return config
def _default_endpoint(
namespace: str, modality: Modality, disaggregation_mode: DisaggregationMode
) -> str:
if modality == Modality.VIDEO_DIFFUSION:
component_name = DEFAULT_DIFFUSION_COMPONENT
elif disaggregation_mode == DisaggregationMode.ENCODE:
component_name = DEFAULT_ENCODE_COMPONENT
elif disaggregation_mode == DisaggregationMode.PREFILL:
component_name = DEFAULT_PREFILL_COMPONENT
else:
component_name = DEFAULT_ENDPOINT_COMPONENT
return f"dyn://{namespace}.{component_name}.{DEFAULT_ENDPOINT_NAME}"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Dynamo TRT-LLM backend configuration ArgGroup."""
from typing import Optional
from tensorrt_llm.llmapi import BuildConfig
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from . import __version__
from .constants import DisaggregationMode, Modality
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
class DynamoTrtllmArgGroup(ArgGroup):
"""TensorRT-LLM-specific Dynamo wrapper configuration."""
def add_arguments(self, parser) -> None:
parser.add_argument(
"--version",
action="version",
version=f"Dynamo Backend TRTLLM {__version__}",
)
g = parser.add_argument_group("Dynamo TRT-LLM Options")
add_argument(
g,
flag_name="--model",
env_var="DYN_TRTLLM_MODEL",
default=DEFAULT_MODEL,
obsolete_flag="--model-path",
help=("Path to disk model or HuggingFace model identifier to load. "),
)
add_argument(
g,
flag_name="--served-model-name",
env_var="DYN_TRTLLM_SERVED_MODEL_NAME",
default=None,
help="Name to serve the model under. Defaults to deriving it from model path.",
)
add_argument(
g,
flag_name="--tensor-parallel-size",
env_var="DYN_TRTLLM_TENSOR_PARALLEL_SIZE",
default=1,
arg_type=int,
help="Tensor parallelism size.",
)
add_argument(
g,
flag_name="--pipeline-parallel-size",
env_var="DYN_TRTLLM_PIPELINE_PARALLEL_SIZE",
default=1,
arg_type=int,
help="Pipeline parallelism size.",
)
add_argument(
g,
flag_name="--expert-parallel-size",
env_var="DYN_TRTLLM_EXPERT_PARALLEL_SIZE",
default=None,
arg_type=int,
help="Expert parallelism size.",
)
add_negatable_bool_argument(
g,
flag_name="--enable-attention-dp",
env_var="DYN_TRTLLM_ENABLE_ATTENTION_DP",
default=False,
help="Enable attention data parallelism. When enabled, attention_dp_size equals tensor_parallel_size.",
)
add_argument(
g,
flag_name="--kv-block-size",
env_var="DYN_TRTLLM_KV_BLOCK_SIZE",
default=32,
arg_type=int,
help="Size of a KV cache block.",
)
add_argument(
g,
flag_name="--gpus-per-node",
env_var="DYN_TRTLLM_GPUS_PER_NODE",
default=None,
arg_type=int,
help="Number of GPUs per node. If not provided, inferred from the environment.",
)
add_argument(
g,
flag_name="--max-batch-size",
env_var="DYN_TRTLLM_MAX_BATCH_SIZE",
default=BuildConfig.model_fields["max_batch_size"].default,
arg_type=int,
help="Maximum number of requests that the engine can schedule.",
)
add_argument(
g,
flag_name="--max-num-tokens",
env_var="DYN_TRTLLM_MAX_NUM_TOKENS",
default=BuildConfig.model_fields["max_num_tokens"].default,
arg_type=int,
help="Maximum number of batched input tokens after padding is removed in each batch.",
)
add_argument(
g,
flag_name="--max-seq-len",
env_var="DYN_TRTLLM_MAX_SEQ_LEN",
default=BuildConfig.model_fields["max_seq_len"].default,
arg_type=int,
help="Maximum total length of one request, including prompt and outputs. If unspecified, the value is deduced from the model config.",
)
add_argument(
g,
flag_name="--max-beam-width",
env_var="DYN_TRTLLM_MAX_BEAM_WIDTH",
default=BuildConfig.model_fields["max_beam_width"].default,
arg_type=int,
help="Maximum number of beams for beam search decoding.",
)
add_argument(
g,
flag_name="--free-gpu-memory-fraction",
env_var="DYN_TRTLLM_FREE_GPU_MEMORY_FRACTION",
default=0.9,
arg_type=float,
help="Free GPU memory fraction reserved for KV Cache, after model weights and buffers are allocated.",
)
add_argument(
g,
flag_name="--extra-engine-args",
env_var="DYN_TRTLLM_EXTRA_ENGINE_ARGS",
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
add_argument(
g,
flag_name="--override-engine-args",
env_var="DYN_TRTLLM_OVERRIDE_ENGINE_ARGS",
default="",
help="Python dictionary string to override specific engine arguments from the YAML file. "
'Example: \'{"tensor_parallel_size": 2, "kv_cache_config": {"enable_block_reuse": false}}\'',
)
add_negatable_bool_argument(
g,
flag_name="--publish-events-and-metrics",
env_var="DYN_TRTLLM_PUBLISH_EVENTS_AND_METRICS",
default=False,
help="If set, publish events and metrics to Dynamo components.",
)
add_argument(
g,
flag_name="--disaggregation-mode",
env_var="DYN_TRTLLM_DISAGGREGATION_MODE",
default=DisaggregationMode.AGGREGATED.value,
choices=[mode.value for mode in DisaggregationMode],
help="Mode to use for disaggregation.",
)
add_argument(
g,
flag_name="--modality",
env_var="DYN_TRTLLM_MODALITY",
default=Modality.TEXT.value,
choices=[m.value for m in Modality],
help="Modality to use for the model.",
)
add_argument(
g,
flag_name="--encode-endpoint",
env_var="DYN_TRTLLM_ENCODE_ENDPOINT",
default="",
help="Endpoint (in 'dyn://namespace.component.endpoint' format) for the encode worker.",
)
add_argument(
g,
flag_name="--allowed-local-media-path",
env_var="DYN_TRTLLM_ALLOWED_LOCAL_MEDIA_PATH",
default="",
help="Path to a directory that is allowed to be accessed by the model.",
)
add_argument(
g,
flag_name="--max-file-size-mb",
env_var="DYN_TRTLLM_MAX_FILE_SIZE_MB",
default=50,
arg_type=int,
help="Maximum size of downloadable embedding files/Image URLs.",
)
diffusion_group = parser.add_argument_group(
"Diffusion Options [Experimental]",
"Options for video_diffusion modality",
)
add_argument(
diffusion_group,
flag_name="--output-dir",
env_var="DYN_TRTLLM_OUTPUT_DIR",
default="/tmp/dynamo_videos",
help="Directory to store generated videos/images.",
)
add_argument(
diffusion_group,
flag_name="--default-height",
env_var="DYN_TRTLLM_DEFAULT_HEIGHT",
default=480,
arg_type=int,
help="Default video/image height in pixels.",
)
add_argument(
diffusion_group,
flag_name="--default-width",
env_var="DYN_TRTLLM_DEFAULT_WIDTH",
default=832,
arg_type=int,
help="Default video/image width in pixels.",
)
add_argument(
diffusion_group,
flag_name="--default-num-frames",
env_var="DYN_TRTLLM_DEFAULT_NUM_FRAMES",
default=81,
arg_type=int,
help="Default number of frames for video generation.",
)
add_argument(
diffusion_group,
flag_name="--default-num-inference-steps",
env_var="DYN_TRTLLM_DEFAULT_NUM_INFERENCE_STEPS",
default=50,
arg_type=int,
help="Default number of inference steps.",
)
add_argument(
diffusion_group,
flag_name="--default-guidance-scale",
env_var="DYN_TRTLLM_DEFAULT_GUIDANCE_SCALE",
default=5.0,
arg_type=float,
help="Default CFG guidance scale.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-teacache",
env_var="DYN_TRTLLM_ENABLE_TEACACHE",
default=False,
help="Enable TeaCache optimization for faster generation.",
)
add_argument(
diffusion_group,
flag_name="--teacache-thresh",
env_var="DYN_TRTLLM_TEACACHE_THRESH",
default=0.2,
arg_type=float,
help="TeaCache threshold.",
)
add_argument(
diffusion_group,
flag_name="--attn-type",
env_var="DYN_TRTLLM_ATTN_TYPE",
default="default",
choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"],
help="Attention type for diffusion models.",
)
add_argument(
diffusion_group,
flag_name="--linear-type",
env_var="DYN_TRTLLM_LINEAR_TYPE",
default="default",
choices=[
"default",
"trtllm-fp8-blockwise",
"trtllm-fp8-per-tensor",
"trtllm-nvfp4",
],
help="Linear type for quantization.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--disable-torch-compile",
env_var="DYN_TRTLLM_DISABLE_TORCH_COMPILE",
default=False,
help="Disable torch.compile optimization.",
)
add_argument(
diffusion_group,
flag_name="--torch-compile-mode",
env_var="DYN_TRTLLM_TORCH_COMPILE_MODE",
default="default",
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode.",
)
add_argument(
diffusion_group,
flag_name="--dit-dp-size",
env_var="DYN_TRTLLM_DIT_DP_SIZE",
default=1,
arg_type=int,
help="Data parallel size for DiT.",
)
add_argument(
diffusion_group,
flag_name="--dit-tp-size",
env_var="DYN_TRTLLM_DIT_TP_SIZE",
default=1,
arg_type=int,
help="Tensor parallel size for DiT.",
)
add_argument(
diffusion_group,
flag_name="--dit-ulysses-size",
env_var="DYN_TRTLLM_DIT_ULYSSES_SIZE",
default=1,
arg_type=int,
help="Ulysses parallel size for DiT.",
)
add_argument(
diffusion_group,
flag_name="--dit-ring-size",
env_var="DYN_TRTLLM_DIT_RING_SIZE",
default=1,
arg_type=int,
help="Ring parallel size for DiT.",
)
add_argument(
diffusion_group,
flag_name="--dit-cfg-size",
env_var="DYN_TRTLLM_DIT_CFG_SIZE",
default=1,
arg_type=int,
help="CFG parallel size for DiT.",
)
add_argument(
diffusion_group,
flag_name="--dit-fsdp-size",
env_var="DYN_TRTLLM_DIT_FSDP_SIZE",
default=1,
arg_type=int,
help="FSDP size for DiT.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-async-cpu-offload",
env_var="DYN_TRTLLM_ENABLE_ASYNC_CPU_OFFLOAD",
default=False,
help="Enable async CPU offload for memory efficiency.",
)
class DynamoTrtllmConfig(ConfigBase):
"""Configuration for Dynamo TRT-LLM backend-specific options."""
model: str
served_model_name: Optional[str] = None
tensor_parallel_size: int
pipeline_parallel_size: int
expert_parallel_size: Optional[int]
enable_attention_dp: bool
kv_block_size: int
gpus_per_node: Optional[int] = None
max_batch_size: int
max_num_tokens: int
max_seq_len: int
max_beam_width: int
free_gpu_memory_fraction: float
extra_engine_args: str
override_engine_args: str
publish_events_and_metrics: bool
disaggregation_mode: DisaggregationMode
modality: Modality
encode_endpoint: str
allowed_local_media_path: str
max_file_size_mb: int
output_dir: str
default_height: int
default_width: int
default_num_frames: int
default_num_inference_steps: int
default_guidance_scale: float
enable_teacache: bool
teacache_thresh: float
attn_type: str
linear_type: str
disable_torch_compile: bool
torch_compile_mode: str
dit_dp_size: int
dit_tp_size: int
dit_ulysses_size: int
dit_ring_size: int
dit_cfg_size: int
dit_fsdp_size: int
enable_async_cpu_offload: bool
def validate(self) -> None:
if isinstance(self.disaggregation_mode, str):
self.disaggregation_mode = DisaggregationMode(self.disaggregation_mode)
if isinstance(self.modality, str):
self.modality = Modality(self.modality)
if not self.served_model_name:
self.served_model_name = None
......@@ -20,14 +20,14 @@ import uvloop
from dynamo.common.utils.runtime import create_runtime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.utils.trtllm_utils import cmd_line_args
from dynamo.trtllm.args import parse_args
from dynamo.trtllm.workers import init_worker
configure_dynamo_logging()
async def worker():
config = cmd_line_args()
config = parse_args()
shutdown_event = asyncio.Event()
runtime, _ = create_runtime(
......
......@@ -16,8 +16,9 @@ if not torch.cuda.is_available():
allow_module_level=True,
)
from dynamo.trtllm.args import Config, parse_args
from dynamo.trtllm.tests.conftest import make_cli_args_fixture
from dynamo.trtllm.utils.trtllm_utils import cmd_line_args
from dynamo.trtllm.utils.trtllm_utils import deep_update
# Get path relative to this test file
REPO_ROOT = Path(__file__).resolve().parents[5]
......@@ -51,13 +52,13 @@ def test_custom_jinja_template_invalid_path(mock_trtllm_cli):
FileNotFoundError,
match=re.escape(f"Custom Jinja template file not found: {invalid_path}"),
):
cmd_line_args() # This will read in from argv
parse_args() # Reads from argv set by fixture
def test_custom_jinja_template_valid_path(mock_trtllm_cli):
"""Test that valid absolute path is stored correctly."""
mock_trtllm_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=JINJA_TEMPLATE_PATH)
config = cmd_line_args()
config = parse_args()
assert config.custom_jinja_template == JINJA_TEMPLATE_PATH, (
f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, "
......@@ -73,10 +74,93 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_trtllm_cli):
cli_path = "$JINJA_DIR/custom_template.jinja"
mock_trtllm_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=cli_path)
config = cmd_line_args()
config = parse_args()
assert "$JINJA_DIR" not in config.custom_jinja_template
assert config.custom_jinja_template == JINJA_TEMPLATE_PATH, (
f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, "
f"got {config.custom_jinja_template}"
)
# ---- Tests for trtllm/args.py (Config, parse_args) ----
def test_parse_args_returns_config_with_expected_attrs(monkeypatch):
"""parse_args returns a Config instance with model, component, and endpoint set."""
monkeypatch.delenv("DYN_NAMESPACE", raising=False)
monkeypatch.delenv("DYN_TRTLLM_MODEL", raising=False)
config = parse_args(["--namespace", "testns", "--model-path", "Qwen/Qwen3-0.6B"])
assert isinstance(config, Config)
assert config.model == "Qwen/Qwen3-0.6B"
assert config.namespace == "testns"
assert config.component == "tensorrt_llm"
assert config.endpoint == "generate"
def test_config_use_kv_events_derived_from_publish_events(monkeypatch):
"""Config.validate sets use_kv_events from publish_events_and_metrics."""
monkeypatch.delenv("DYN_TRTLLM_PUBLISH_EVENTS", raising=False)
config = parse_args(["--publish-events"])
assert config.publish_events_and_metrics is True
assert config.use_kv_events is True
config_off = parse_args(["--no-publish-events"])
assert config_off.publish_events_and_metrics is False
assert config_off.use_kv_events is False
def test_config_has_connector(monkeypatch):
"""Config.has_connector returns True only for the single configured connector."""
monkeypatch.delenv("DYN_CONNECTOR", raising=False)
config_none = parse_args(["--connector", "none"])
assert config_none.has_connector("none") is True
assert config_none.has_connector("kvbm") is False
config_kvbm = parse_args(["--connector", "kvbm"])
assert config_kvbm.has_connector("kvbm") is True
assert config_kvbm.has_connector("none") is False
def test_config_multiple_connectors_fails(monkeypatch):
"""Config.validate fails if multiple connectors are provided."""
monkeypatch.delenv("DYN_CONNECTOR", raising=False)
with pytest.raises(
ValueError,
match="TRT-LLM supports at most one connector entry. Use `--connector none` or `--connector kvbm`.",
):
parse_args(["--connector", "none", "kvbm"])
# ---- Tests for trtllm_utils.deep_update ----
def test_deep_update_nested_merge():
"""deep_update merges nested dicts without removing existing keys."""
target = {"a": 1, "b": {"x": 10, "y": 20}}
source = {"b": {"y": 21, "z": 30}}
deep_update(target, source)
assert target == {"a": 1, "b": {"x": 10, "y": 21, "z": 30}}
def test_deep_update_overwrites_scalar_with_value():
"""deep_update overwrites a key with a non-dict value."""
target = {"a": 1, "b": {"x": 10}}
source = {"a": 2, "b": 99}
deep_update(target, source)
assert target == {"a": 2, "b": 99}
def test_deep_update_empty_source_unchanged():
"""deep_update with empty source leaves target unchanged."""
target = {"a": 1, "b": {"x": 10}}
deep_update(target, {})
assert target == {"a": 1, "b": {"x": 10}}
def test_deep_update_adds_new_keys():
"""deep_update adds new keys from source that are not in target."""
target = {"a": 1}
source = {"b": 2, "c": {"nested": 3}}
deep_update(target, source)
assert target == {"a": 1, "b": 2, "c": {"nested": 3}}
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
from typing import Optional
"""Shared utilities for the TRT-LLM backend."""
from tensorrt_llm.llmapi import BuildConfig
from collections.abc import Mapping
from typing import Any
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.config_dump import add_config_dump_args, register_encoder
from dynamo.common.utils.runtime import parse_endpoint
from dynamo.trtllm import __version__
from dynamo.trtllm.constants import DisaggregationMode, Modality
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
# Default endpoints for TensorRT-LLM workers
DEFAULT_ENDPOINT = (
f"dyn://{DYN_NAMESPACE}.tensorrt_llm.generate" # Decode/aggregated workers
)
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate" # Prefill workers
DEFAULT_ENCODE_ENDPOINT = (
f"dyn://{DYN_NAMESPACE}.tensorrt_llm_encode.generate" # Encode workers
)
DEFAULT_DIFFUSION_ENDPOINT = (
f"dyn://{DYN_NAMESPACE}.diffusion.generate" # Diffusion workers
)
DEFAULT_MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
DEFAULT_DISAGGREGATION_MODE = DisaggregationMode.AGGREGATED
class Config:
"""Command line parameters or defaults"""
def __init__(self) -> None:
self.namespace: str = ""
self.component: str = ""
self.endpoint: str = ""
self.model_path: str = ""
self.served_model_name: Optional[str] = None
self.tensor_parallel_size: int = 1
self.pipeline_parallel_size: int = 1
self.expert_parallel_size: Optional[int] = None
self.enable_attention_dp: bool = False
self.kv_block_size: int = 32
self.gpus_per_node: Optional[int] = None
self.max_batch_size: int = BuildConfig.model_fields["max_batch_size"].default
self.max_num_tokens: int = BuildConfig.model_fields["max_num_tokens"].default
self.max_seq_len: int = BuildConfig.model_fields["max_seq_len"].default
self.max_beam_width: int = BuildConfig.model_fields["max_beam_width"].default
self.free_gpu_memory_fraction: float = 0.9
self.extra_engine_args: str = ""
self.override_engine_args: str = ""
self.publish_events_and_metrics: bool = False
self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
self.encode_endpoint: str = ""
self.modality: str = "text"
self.allowed_local_media_path: str = ""
self.max_file_size_mb: int = 50
self.encoder_cache_capacity_gb: float = 0
self.reasoning_parser: Optional[str] = None
self.tool_call_parser: Optional[str] = None
self.dump_config_to: Optional[str] = None
self.custom_jinja_template: Optional[str] = None
self.dyn_endpoint_types: str = "chat,completions"
self.discovery_backend: str = ""
self.request_plane: str = ""
self.event_plane: str = ""
self.enable_local_indexer: bool = True
# Whether to enable NATS for KV events (derived from publish_events_and_metrics)
self.use_kv_events: bool = False
# Diffusion-specific config (only used when modality is video_diffusion or image_diffusion)
self.output_dir: str = "/tmp/dynamo_videos"
self.default_height: int = 480
self.default_width: int = 832
self.default_num_frames: int = 81
self.default_num_inference_steps: int = 50
self.default_guidance_scale: float = 5.0
self.enable_teacache: bool = False
self.teacache_thresh: float = 0.2
self.attn_type: str = "default"
self.linear_type: str = "default"
self.disable_torch_compile: bool = False
self.torch_compile_mode: str = "default"
self.dit_dp_size: int = 1
self.dit_tp_size: int = 1
self.dit_ulysses_size: int = 1
self.dit_ring_size: int = 1
self.dit_cfg_size: int = 1
self.dit_fsdp_size: int = 1
self.enable_async_cpu_offload: bool = False
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"served_model_name={self.served_model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"pipeline_parallel_size={self.pipeline_parallel_size}, "
f"expert_parallel_size={self.expert_parallel_size}, "
f"enable_attention_dp={self.enable_attention_dp}, "
f"kv_block_size={self.kv_block_size}, "
f"gpus_per_node={self.gpus_per_node}, "
f"max_batch_size={self.max_batch_size}, "
f"max_num_tokens={self.max_num_tokens}, "
f"max_seq_len={self.max_seq_len}, "
f"max_beam_width={self.max_beam_width}, "
f"free_gpu_memory_fraction={self.free_gpu_memory_fraction}, "
f"extra_engine_args={self.extra_engine_args}, "
f"override_engine_args={self.override_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"encode_endpoint={self.encode_endpoint}, "
f"modality={self.modality}, "
f"allowed_local_media_path={self.allowed_local_media_path}, "
f"max_file_size_mb={self.max_file_size_mb}, "
f"encoder_cache_capacity_gb={self.encoder_cache_capacity_gb}, "
f"reasoning_parser={self.reasoning_parser}, "
f"tool_call_parser={self.tool_call_parser}, "
f"dump_config_to={self.dump_config_to}, "
f"custom_jinja_template={self.custom_jinja_template}, "
f"discovery_backend={self.discovery_backend}, "
f"request_plane={self.request_plane}, "
f"event_plane={self.event_plane}, "
f"enable_local_indexer={self.enable_local_indexer}, "
f"use_kv_events={self.use_kv_events}, "
f"output_dir={self.output_dir}, "
f"dit_dp_size={self.dit_dp_size}, "
f"dit_tp_size={self.dit_tp_size})"
)
@register_encoder(Config)
def _preprocess_for_encode_config(
obj: Config,
) -> dict: # pyright: ignore[reportUnusedFunction]
"""Convert Config object to dictionary for encoding."""
return obj.__dict__
def cmd_line_args():
"""Parse command-line arguments for the TensorRT-LLM backend.
Returns:
Config: Parsed configuration object.
"""
parser = argparse.ArgumentParser(
description="TensorRT-LLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--version", action="version", version=f"Dynamo Backend TRTLLM {__version__}"
)
parser.add_argument(
"--endpoint",
type=str,
default="",
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT} for decode/aggregated, {DEFAULT_PREFILL_ENDPOINT} for prefill workers, or {DEFAULT_ENCODE_ENDPOINT} for encode workers",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL_PATH,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL_PATH}",
)
parser.add_argument(
"--served-model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Tensor parallelism size."
)
parser.add_argument(
"--pipeline-parallel-size",
type=int,
default=None,
help="Pipeline parallelism size.",
)
parser.add_argument(
"--expert-parallel-size",
type=int,
default=None,
help="expert parallelism size.",
)
parser.add_argument(
"--enable-attention-dp",
action="store_true",
help="Enable attention data parallelism. When enabled, attention_dp_size equals tensor_parallel_size.",
)
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block."
)
parser.add_argument(
"--gpus-per-node",
type=int,
default=None,
help="Number of GPUs per node. If not provided, will be inferred from the environment.",
)
parser.add_argument(
"--max-batch-size",
type=int,
default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.",
)
parser.add_argument(
"--max-num-tokens",
type=int,
default=BuildConfig.model_fields["max_num_tokens"].default,
help="Maximum number of batched input tokens after padding is removed in each batch.",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=BuildConfig.model_fields["max_seq_len"].default,
help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.",
)
parser.add_argument(
"--max-beam-width",
type=int,
default=BuildConfig.model_fields["max_beam_width"].default,
help="Maximum number of beams for beam search decoding.",
)
parser.add_argument(
"--free-gpu-memory-fraction",
type=float,
default=None,
help="Free GPU memory fraction reserved for KV Cache, after allocating model weights and buffers.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--override-engine-args",
type=str,
default="",
help='Python dictionary string to override specific engine arguments from the YAML file. Example: \'{"tensor_parallel_size": 2, "kv_cache_config": {"enable_block_reuse": false}}\'',
)
parser.add_argument(
"--publish-events-and-metrics",
action="store_true",
help="If set, publish events and metrics to the dynamo components.",
)
parser.add_argument(
"--disaggregation-mode",
type=str,
default=DEFAULT_DISAGGREGATION_MODE,
choices=[mode.value for mode in DisaggregationMode],
help=f"Mode to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_MODE}",
)
parser.add_argument(
"--use-nixl-connect",
type=bool,
default=False,
help="Use NIXL Connect for communication between workers.",
)
parser.add_argument(
"--modality",
type=str,
default="text",
choices=[m.value for m in Modality],
help="Modality to use for the model. Default: text. "
"Options: text (LLM), multimodal (VLM), video_diffusion.",
)
parser.add_argument(
"--encode-endpoint",
type=str,
default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) for the encode worker. e.g. {DEFAULT_ENCODE_ENDPOINT}",
)
parser.add_argument(
"--allowed-local-media-path",
type=str,
default="",
help="Path to a directory that is allowed to be accessed by the model. Default: empty",
)
parser.add_argument(
"--max-file-size-mb",
type=int,
default=50,
help="Maximum size of downloadable embedding files/Image URLs. Default: 50MB",
)
parser.add_argument(
"--dyn-encoder-cache-capacity-gb",
type=float,
default=0,
help="Capacity of the encoder cache in GB for multimodal embeddings. Default: 0",
)
# To avoid name conflicts with different backends, adoped prefix "dyn-" for dynamo specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
choices=get_tool_parser_names(),
help="Tool call parser name for the model.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
choices=get_reasoning_parser_names(),
help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
)
parser.add_argument(
"--connector",
type=str,
default="none",
choices=["none", "kvbm"],
help="Connector to use for the model.",
)
add_config_dump_args(parser)
parser.add_argument(
"--custom-jinja-template",
type=str,
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
parser.add_argument(
"--dyn-endpoint-types",
type=str,
default="chat,completions",
help="Comma-separated list of endpoint types to enable. Options: 'chat', 'completions'. Default: 'chat,completions'. Use 'completions' for models without chat templates.",
)
parser.add_argument(
"--discovery-backend",
type=str,
choices=["kubernetes", "etcd", "file", "mem"],
default=os.environ.get("DYN_DISCOVERY_BACKEND", "etcd"),
help="Discovery backend: kubernetes (K8s API), etcd (distributed KV), file (local filesystem), mem (in-memory). Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
parser.add_argument(
"--request-plane",
type=str,
choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "tcp"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
parser.add_argument(
"--event-plane",
type=str,
choices=["nats", "zmq"],
default=os.environ.get("DYN_EVENT_PLANE", "nats"),
help="Determines how events are published [nats|zmq]",
)
parser.add_argument(
"--durable-kv-events",
action="store_true",
default=os.environ.get("DYN_DURABLE_KV_EVENTS", "false").lower() == "true",
help="Enable durable KV events using NATS JetStream instead of the local indexer. By default, local indexer is enabled for lower latency. Use this flag when you need durability and multi-replica router consistency. Requires NATS with JetStream enabled. Can also be set via DYN_DURABLE_KV_EVENTS=true env var.",
)
# Diffusion-specific options (only used when modality is video_diffusion or image_diffusion)
diffusion_group = parser.add_argument_group(
"Diffusion Options [Experimental]",
"Options for video_diffusion modality",
)
diffusion_group.add_argument(
"--output-dir",
type=str,
default="/tmp/dynamo_videos",
help="Directory to store generated videos/images. Default: /tmp/dynamo_videos",
)
diffusion_group.add_argument(
"--default-height",
type=int,
default=480,
help="Default video/image height in pixels. Default: 480",
)
diffusion_group.add_argument(
"--default-width",
type=int,
default=832,
help="Default video/image width in pixels. Default: 832",
)
diffusion_group.add_argument(
"--default-num-frames",
type=int,
default=81,
help="Default number of frames for video generation. Default: 81",
)
diffusion_group.add_argument(
"--default-num-inference-steps",
type=int,
default=50,
help="Default number of inference steps. Default: 50",
)
diffusion_group.add_argument(
"--default-guidance-scale",
type=float,
default=5.0,
help="Default CFG guidance scale. Default: 5.0",
)
diffusion_group.add_argument(
"--enable-teacache",
action="store_true",
help="Enable TeaCache optimization for faster generation.",
)
diffusion_group.add_argument(
"--teacache-thresh",
type=float,
default=0.2,
help="TeaCache threshold. Default: 0.2",
)
diffusion_group.add_argument(
"--attn-type",
type=str,
default="default",
choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"],
help="Attention type for diffusion models. Default: default",
)
diffusion_group.add_argument(
"--linear-type",
type=str,
default="default",
choices=[
"default",
"trtllm-fp8-blockwise",
"trtllm-fp8-per-tensor",
"trtllm-nvfp4",
],
help="Linear type for quantization. Default: default",
)
diffusion_group.add_argument(
"--disable-torch-compile",
action="store_true",
help="Disable torch.compile optimization.",
)
diffusion_group.add_argument(
"--torch-compile-mode",
type=str,
default="default",
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode. Default: default",
)
diffusion_group.add_argument(
"--dit-dp-size",
type=int,
default=1,
help="Data parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-tp-size",
type=int,
default=1,
help="Tensor parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-ulysses-size",
type=int,
default=1,
help="Ulysses parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-ring-size",
type=int,
default=1,
help="Ring parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-cfg-size",
type=int,
default=1,
help="CFG parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-fsdp-size",
type=int,
default=1,
help="FSDP size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--enable-async-cpu-offload",
action="store_true",
help="Enable async CPU offload for memory efficiency.",
)
args = parser.parse_args()
config = Config()
# Set the model path and served model name.
config.model_path = args.model_path
if args.served_model_name:
config.served_model_name = args.served_model_name
else:
# This becomes an `Option` on the Rust side
config.served_model_name = None
# Set modality
config.modality = args.modality
# Set the disaggregation mode.
config.disaggregation_mode = DisaggregationMode(args.disaggregation_mode)
# Set the appropriate default for the endpoint based on modality and disaggregation mode
if args.endpoint == "":
if Modality(args.modality) == Modality.VIDEO_DIFFUSION:
args.endpoint = DEFAULT_DIFFUSION_ENDPOINT
elif config.disaggregation_mode == DisaggregationMode.ENCODE:
args.endpoint = DEFAULT_ENCODE_ENDPOINT
elif config.disaggregation_mode == DisaggregationMode.PREFILL:
args.endpoint = DEFAULT_PREFILL_ENDPOINT
else:
# Decode and aggregated workers use "tensorrt_llm" component
args.endpoint = DEFAULT_ENDPOINT
endpoint = args.endpoint
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.encode_endpoint = args.encode_endpoint
config.allowed_local_media_path = args.allowed_local_media_path
config.max_file_size_mb = args.max_file_size_mb
config.encoder_cache_capacity_gb = args.dyn_encoder_cache_capacity_gb
config.tensor_parallel_size = args.tensor_parallel_size
if args.pipeline_parallel_size is not None:
config.pipeline_parallel_size = args.pipeline_parallel_size
if args.expert_parallel_size is not None:
config.expert_parallel_size = args.expert_parallel_size
config.enable_attention_dp = args.enable_attention_dp
if args.gpus_per_node is not None:
config.gpus_per_node = args.gpus_per_node
if args.free_gpu_memory_fraction is not None:
config.free_gpu_memory_fraction = args.free_gpu_memory_fraction
config.max_batch_size = args.max_batch_size
config.max_num_tokens = args.max_num_tokens
config.max_seq_len = args.max_seq_len
config.max_beam_width = args.max_beam_width
config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args
config.override_engine_args = args.override_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.reasoning_parser = args.dyn_reasoning_parser
config.tool_call_parser = args.dyn_tool_call_parser
config.dump_config_to = args.dump_config_to
config.dyn_endpoint_types = args.dyn_endpoint_types
config.discovery_backend = args.discovery_backend
config.request_plane = args.request_plane
config.event_plane = args.event_plane
config.enable_local_indexer = not args.durable_kv_events
# Derive use_kv_events from publish_events_and_metrics
config.use_kv_events = config.publish_events_and_metrics
config.connector = args.connector
# Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template:
expanded_template_path = os.path.expandvars(
os.path.expanduser(args.custom_jinja_template)
)
# Validate custom Jinja template file exists
if not os.path.isfile(expanded_template_path):
raise FileNotFoundError(
f"Custom Jinja template file not found: {expanded_template_path}"
)
config.custom_jinja_template = expanded_template_path
else:
config.custom_jinja_template = None
# Copy diffusion-specific args (only relevant for video_diffusion/image_diffusion)
config.output_dir = args.output_dir
config.default_height = args.default_height
config.default_width = args.default_width
config.default_num_frames = args.default_num_frames
config.default_num_inference_steps = args.default_num_inference_steps
config.default_guidance_scale = args.default_guidance_scale
config.enable_teacache = args.enable_teacache
config.teacache_thresh = args.teacache_thresh
config.attn_type = args.attn_type
config.linear_type = args.linear_type
config.disable_torch_compile = args.disable_torch_compile
config.torch_compile_mode = args.torch_compile_mode
config.dit_dp_size = args.dit_dp_size
config.dit_tp_size = args.dit_tp_size
config.dit_ulysses_size = args.dit_ulysses_size
config.dit_ring_size = args.dit_ring_size
config.dit_cfg_size = args.dit_cfg_size
config.dit_fsdp_size = args.dit_fsdp_size
config.enable_async_cpu_offload = args.enable_async_cpu_offload
return config
def deep_update(target: dict, source: dict) -> None:
"""
Recursively update nested dictionaries.
def deep_update(target: dict[str, Any], source: Mapping[str, Any]) -> None:
"""Recursively update nested dictionaries.
Args:
target: Dictionary to update
source: Dictionary with new values
target: Dictionary to update.
source: Dictionary with new values.
"""
for key, value in source.items():
if isinstance(value, dict) and key in target and isinstance(target[key], dict):
......
......@@ -20,8 +20,8 @@ import asyncio
import logging
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import Modality
from dynamo.trtllm.utils.trtllm_utils import Config
from dynamo.trtllm.workers.llm_worker import init_llm_worker
......@@ -40,7 +40,7 @@ async def init_worker(
"""
logging.info(f"Initializing worker with modality={config.modality}")
modality = Modality(config.modality)
modality = config.modality
if Modality.is_diffusion(modality):
if modality == Modality.VIDEO_DIFFUSION:
......
......@@ -45,6 +45,7 @@ from dynamo.llm import (
register_model,
)
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
......@@ -54,7 +55,7 @@ from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig,
RequestHandlerFactory,
)
from dynamo.trtllm.utils.trtllm_utils import Config, deep_update
from dynamo.trtllm.utils.trtllm_utils import deep_update
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
......@@ -92,17 +93,17 @@ async def get_engine_runtime_config(
def build_kv_connector_config(config: Config):
if config.connector is not None:
if config.connector == "kvbm":
if config.connector:
if config.connector[0] == "kvbm":
return KvCacheConnectorConfig(
connector_module="kvbm.trtllm_integration.connector",
connector_scheduler_class="DynamoKVBMConnectorLeader",
connector_worker_class="DynamoKVBMConnectorWorker",
)
elif config.connector == "none":
elif config.connector[0] == "none":
return None
else:
logging.error(f"Invalid connector: {config.connector}")
logging.error(f"Invalid connector: {config.connector[0]}")
sys.exit(1)
return None
......@@ -138,7 +139,7 @@ async def init_llm_worker(
component = runtime.namespace(config.namespace).component(config.component)
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path)
model_path = str(config.model)
if config.gpus_per_node is None:
gpus_per_node = device_count()
......@@ -151,7 +152,7 @@ async def init_llm_worker(
free_gpu_memory_fraction=config.free_gpu_memory_fraction
)
if config.connector is not None and "kvbm" in config.connector:
if config.has_connector("kvbm"):
kv_cache_config.enable_partial_reuse = False
dynamic_batch_config = DynamicBatchConfig(
......@@ -275,15 +276,13 @@ async def init_llm_worker(
if config.disaggregation_mode == DisaggregationMode.PREFILL:
model_type = ModelType.Prefill
else:
model_type = parse_endpoint_types(config.dyn_endpoint_types)
logging.info(
f"Registering model with endpoint types: {config.dyn_endpoint_types}"
)
model_type = parse_endpoint_types(config.endpoint_types)
logging.info(f"Registering model with endpoint types: {config.endpoint_types}")
# Warn if custom template provided but chat endpoint not enabled
if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types:
if config.custom_jinja_template and "chat" not in config.endpoint_types:
logging.warning(
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --endpoint-types. "
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
)
......@@ -298,12 +297,10 @@ async def init_llm_worker(
if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False
model_config = AutoConfig.from_pretrained(
config.model_path, trust_remote_code=True
)
model_config = AutoConfig.from_pretrained(config.model, trust_remote_code=True)
multimodal_processor = MultimodalRequestProcessor(
model_type=model_config.model_type,
model_dir=config.model_path,
model_dir=config.model,
max_file_size_mb=config.max_file_size_mb,
tokenizer=tokenizer,
allowed_local_media_path=config.allowed_local_media_path,
......@@ -322,7 +319,7 @@ async def init_llm_worker(
)
# Prepare model name for metrics
model_name_for_metrics = config.served_model_name or config.model_path
model_name_for_metrics = config.served_model_name or config.model
# Construct Prometheus gauges directly; passed through to the engine and publisher
# via explicit parameters (no module-level global).
......@@ -357,8 +354,8 @@ async def init_llm_worker(
# Both parameters control the same thing: how many requests can be processed simultaneously
runtime_config.max_num_seqs = config.max_batch_size
runtime_config.max_num_batched_tokens = config.max_num_tokens
runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.dyn_reasoning_parser
runtime_config.tool_call_parser = config.dyn_tool_call_parser
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
runtime_config.enable_local_indexer = (
config.enable_local_indexer
......@@ -386,7 +383,7 @@ async def init_llm_worker(
metrics_collector = None
if config.publish_events_and_metrics:
try:
model_name_for_metrics = config.served_model_name or config.model_path
model_name_for_metrics = config.served_model_name or config.model
metrics_collector = MetricsCollector(
{"model_name": model_name_for_metrics, "engine_type": "trtllm"}
)
......@@ -430,7 +427,7 @@ async def init_llm_worker(
metrics_collector=metrics_collector,
kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event,
encoder_cache_capacity_gb=config.encoder_cache_capacity_gb,
encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb,
)
# Register the model with runtime config
......@@ -441,7 +438,7 @@ async def init_llm_worker(
model_input,
model_type,
endpoint,
config.model_path,
config.model,
config.served_model_name,
kv_cache_block_size=config.kv_block_size,
runtime_config=runtime_config,
......@@ -457,8 +454,8 @@ async def init_llm_worker(
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
# Use model_path as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model_path
# Use model as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model
metrics_labels = [
(
prometheus_names.labels.MODEL,
......
......@@ -12,7 +12,7 @@ import logging
from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.utils.trtllm_utils import Config
from dynamo.trtllm.args import Config
async def init_video_diffusion_worker(
......@@ -58,7 +58,7 @@ async def init_video_diffusion_worker(
discovery_backend=config.discovery_backend,
request_plane=config.request_plane,
event_plane=config.event_plane,
model_path=config.model_path,
model_path=config.model,
served_model_name=config.served_model_name,
output_dir=config.output_dir,
default_height=config.default_height,
......@@ -93,7 +93,7 @@ async def init_video_diffusion_worker(
handler = VideoGenerationHandler(component, engine, diffusion_config)
# Register the model with Dynamo's discovery system
model_name = config.served_model_name or config.model_path
model_name = config.served_model_name or config.model
# Use ModelType.Videos for video generation
if not hasattr(ModelType, "Videos"):
......@@ -111,7 +111,7 @@ async def init_video_diffusion_worker(
ModelInput.Text,
model_type,
endpoint,
config.model_path,
config.model,
model_name,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment