"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "3057af00b6ceb41e8179c177d5446917a102bdba"
Unverified Commit 5a67b246 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: Migrate trtllm configuration (#6297)

parent 9a93eb75
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Argument parsing and typed config for Dynamo TRT-LLM."""
import argparse
import logging
import os
import sys
from typing import Any, Dict, Optional, Sequence
from dynamo.common.config_dump import register_encoder
from dynamo.common.configuration.groups.runtime_args import (
DynamoRuntimeArgGroup,
DynamoRuntimeConfig,
)
from dynamo.common.utils.runtime import parse_endpoint
from dynamo.trtllm.backend_args import DynamoTrtllmArgGroup, DynamoTrtllmConfig
from dynamo.trtllm.constants import DisaggregationMode, Modality
# Component names used by _default_endpoint() when no explicit endpoint is
# configured. The choice depends on modality and disaggregation mode.
DEFAULT_ENDPOINT_COMPONENT = "tensorrt_llm"  # fallback for all other cases
DEFAULT_PREFILL_COMPONENT = "prefill"  # DisaggregationMode.PREFILL
DEFAULT_ENCODE_COMPONENT = "tensorrt_llm_encode"  # DisaggregationMode.ENCODE
DEFAULT_DIFFUSION_COMPONENT = "diffusion"  # Modality.VIDEO_DIFFUSION
# Endpoint name appended to every default endpoint URI.
DEFAULT_ENDPOINT_NAME = "generate"
# Connector names Config.validate() accepts as-is; any other single value is
# replaced by "none" with a warning.
VALID_TRTLLM_CONNECTORS = {"none", "kvbm"}
class Config(DynamoRuntimeConfig, DynamoTrtllmConfig):
    """Worker configuration combining the shared runtime and TRT-LLM options.

    Besides the inherited fields, two attributes are filled in after parsing:
    ``component`` (assigned by ``parse_args``) and ``use_kv_events``
    (assigned by ``validate``).
    """

    component: str
    use_kv_events: bool

    def validate(self) -> None:
        """Validate both parent configs and normalise TRT-LLM specific fields.

        Raises:
            ValueError: if more than one connector entry was supplied.
        """
        for parent in (DynamoRuntimeConfig, DynamoTrtllmConfig):
            parent.validate(self)

        # KV events are published exactly when events/metrics publishing is on.
        self.use_kv_events = self.publish_events_and_metrics

        # The runtime args accept a list of connectors, whereas TRT-LLM works
        # with a single one that must be in VALID_TRTLLM_CONNECTORS.
        if not self.connector:
            return
        if len(self.connector) > 1:
            raise ValueError(
                "TRT-LLM supports at most one connector entry. Use `--connector none` or `--connector kvbm`."
            )
        if self.connector[0] in VALID_TRTLLM_CONNECTORS:
            return

        # Unsupported single connector: report where it most likely came from
        # and continue without a connector instead of failing.
        if "DYN_CONNECTOR" in os.environ:
            source = f"DYN_CONNECTOR environment variable ('{os.environ['DYN_CONNECTOR']}')"
        else:
            source = f"shared runtime default ('{self.connector[0]}')"
        logging.warning(
            f"TRT-LLM does not support connector '{self.connector[0]}' (set via {source}). "
            f"Supported connectors: {VALID_TRTLLM_CONNECTORS}. Falling back to 'none'."
        )
        self.connector = ["none"]

    def has_connector(self, connector_name: str) -> bool:
        """Return True when ``connector_name`` is the first configured connector."""
        if self.connector is None or len(self.connector) == 0:
            return False
        return connector_name == self.connector[0]
@register_encoder(Config)
def _preprocess_for_encode_config(config: Config) -> Dict[str, Any]:
    """Encoder registered for ``Config``: expose the instance attribute dict."""
    return vars(config)
def parse_args(argv: Optional[Sequence[str]] = None) -> Config:
    """Build a validated ``Config`` from command-line arguments.

    Args:
        argv: Argument list to parse. When ``None``, ``sys.argv[1:]`` is used.

    Returns:
        The validated configuration, with ``namespace``, ``component`` and
        ``endpoint`` replaced by the three parts of the resolved endpoint.

    Raises:
        FileNotFoundError: if ``custom_jinja_template`` is set but, after
            environment-variable and ``~`` expansion, is not an existing file.
    """
    if argv is None:
        cli_args = sys.argv[1:]
    else:
        cli_args = list(argv)

    parser = argparse.ArgumentParser(
        description="Dynamo TensorRT-LLM worker configuration",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # Shared runtime options first, then the TRT-LLM specific ones.
    for group_cls in (DynamoRuntimeArgGroup, DynamoTrtllmArgGroup):
        group_cls().add_arguments(parser)

    config = Config.from_cli_args(parser.parse_args(cli_args))
    config.validate()

    # TODO: move this to common configuration.
    if not config.custom_jinja_template:
        # Normalise any falsy value (e.g. empty string) to None.
        config.custom_jinja_template = None
    else:
        expanded_template_path = os.path.expanduser(
            os.path.expandvars(config.custom_jinja_template)
        )
        if not os.path.isfile(expanded_template_path):
            raise FileNotFoundError(
                f"Custom Jinja template file not found: {expanded_template_path}"
            )
        config.custom_jinja_template = expanded_template_path

    # Use the explicit endpoint if given, otherwise derive one from the
    # namespace, modality and disaggregation mode.
    endpoint = config.endpoint
    if not endpoint:
        endpoint = _default_endpoint(
            namespace=config.namespace,
            modality=config.modality,
            disaggregation_mode=config.disaggregation_mode,
        )

    # Split the endpoint into its parts and store them back on the config.
    config.namespace, config.component, config.endpoint = parse_endpoint(endpoint)
    return config
def _default_endpoint(
    namespace: str, modality: Modality, disaggregation_mode: DisaggregationMode
) -> str:
    """Return the default ``dyn://namespace.component.endpoint`` URI.

    The component is chosen in this order: video-diffusion modality, encode
    mode, prefill mode, then the generic TRT-LLM component.
    """

    def _format(component_name: str) -> str:
        return f"dyn://{namespace}.{component_name}.{DEFAULT_ENDPOINT_NAME}"

    # Modality takes precedence over the disaggregation mode.
    if modality == Modality.VIDEO_DIFFUSION:
        return _format(DEFAULT_DIFFUSION_COMPONENT)
    if disaggregation_mode == DisaggregationMode.ENCODE:
        return _format(DEFAULT_ENCODE_COMPONENT)
    if disaggregation_mode == DisaggregationMode.PREFILL:
        return _format(DEFAULT_PREFILL_COMPONENT)
    return _format(DEFAULT_ENDPOINT_COMPONENT)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Dynamo TRT-LLM backend configuration ArgGroup."""
from typing import Optional
from tensorrt_llm.llmapi import BuildConfig
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from . import __version__
from .constants import DisaggregationMode, Modality
# Default value registered for the --model option.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
class DynamoTrtllmArgGroup(ArgGroup):
    """Argument group contributing the TRT-LLM backend options to a parser."""

    def add_arguments(self, parser) -> None:
        """Register ``--version`` and all TRT-LLM / diffusion options on ``parser``.

        Options are described as ordered ``(registrar, keyword-arguments)``
        pairs and replayed against their argparse group, so the registration
        order matches the order of the lists below.
        """

        def option(**spec):
            # Option registered through add_argument.
            return add_argument, spec

        def switch(**spec):
            # Boolean option registered through add_negatable_bool_argument.
            return add_negatable_bool_argument, spec

        def register(group, entries) -> None:
            for registrar, spec in entries:
                registrar(group, **spec)

        parser.add_argument(
            "--version",
            action="version",
            version=f"Dynamo Backend TRTLLM {__version__}",
        )

        # ---- LLM engine / worker options ----
        backend_group = parser.add_argument_group("Dynamo TRT-LLM Options")
        backend_options = [
            option(
                flag_name="--model",
                env_var="DYN_TRTLLM_MODEL",
                default=DEFAULT_MODEL,
                obsolete_flag="--model-path",
                help="Path to disk model or HuggingFace model identifier to load. ",
            ),
            option(
                flag_name="--served-model-name",
                env_var="DYN_TRTLLM_SERVED_MODEL_NAME",
                default=None,
                help="Name to serve the model under. Defaults to deriving it from model path.",
            ),
            option(
                flag_name="--tensor-parallel-size",
                env_var="DYN_TRTLLM_TENSOR_PARALLEL_SIZE",
                default=1,
                arg_type=int,
                help="Tensor parallelism size.",
            ),
            option(
                flag_name="--pipeline-parallel-size",
                env_var="DYN_TRTLLM_PIPELINE_PARALLEL_SIZE",
                default=1,
                arg_type=int,
                help="Pipeline parallelism size.",
            ),
            option(
                flag_name="--expert-parallel-size",
                env_var="DYN_TRTLLM_EXPERT_PARALLEL_SIZE",
                default=None,
                arg_type=int,
                help="Expert parallelism size.",
            ),
            switch(
                flag_name="--enable-attention-dp",
                env_var="DYN_TRTLLM_ENABLE_ATTENTION_DP",
                default=False,
                help="Enable attention data parallelism. When enabled, attention_dp_size equals tensor_parallel_size.",
            ),
            option(
                flag_name="--kv-block-size",
                env_var="DYN_TRTLLM_KV_BLOCK_SIZE",
                default=32,
                arg_type=int,
                help="Size of a KV cache block.",
            ),
            option(
                flag_name="--gpus-per-node",
                env_var="DYN_TRTLLM_GPUS_PER_NODE",
                default=None,
                arg_type=int,
                help="Number of GPUs per node. If not provided, inferred from the environment.",
            ),
            # The four limits below take their defaults from TRT-LLM's BuildConfig.
            option(
                flag_name="--max-batch-size",
                env_var="DYN_TRTLLM_MAX_BATCH_SIZE",
                default=BuildConfig.model_fields["max_batch_size"].default,
                arg_type=int,
                help="Maximum number of requests that the engine can schedule.",
            ),
            option(
                flag_name="--max-num-tokens",
                env_var="DYN_TRTLLM_MAX_NUM_TOKENS",
                default=BuildConfig.model_fields["max_num_tokens"].default,
                arg_type=int,
                help="Maximum number of batched input tokens after padding is removed in each batch.",
            ),
            option(
                flag_name="--max-seq-len",
                env_var="DYN_TRTLLM_MAX_SEQ_LEN",
                default=BuildConfig.model_fields["max_seq_len"].default,
                arg_type=int,
                help="Maximum total length of one request, including prompt and outputs. If unspecified, the value is deduced from the model config.",
            ),
            option(
                flag_name="--max-beam-width",
                env_var="DYN_TRTLLM_MAX_BEAM_WIDTH",
                default=BuildConfig.model_fields["max_beam_width"].default,
                arg_type=int,
                help="Maximum number of beams for beam search decoding.",
            ),
            option(
                flag_name="--free-gpu-memory-fraction",
                env_var="DYN_TRTLLM_FREE_GPU_MEMORY_FRACTION",
                default=0.9,
                arg_type=float,
                help="Free GPU memory fraction reserved for KV Cache, after model weights and buffers are allocated.",
            ),
            option(
                flag_name="--extra-engine-args",
                env_var="DYN_TRTLLM_EXTRA_ENGINE_ARGS",
                default="",
                help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
            ),
            option(
                flag_name="--override-engine-args",
                env_var="DYN_TRTLLM_OVERRIDE_ENGINE_ARGS",
                default="",
                help="Python dictionary string to override specific engine arguments from the YAML file. "
                'Example: \'{"tensor_parallel_size": 2, "kv_cache_config": {"enable_block_reuse": false}}\'',
            ),
            switch(
                flag_name="--publish-events-and-metrics",
                env_var="DYN_TRTLLM_PUBLISH_EVENTS_AND_METRICS",
                default=False,
                help="If set, publish events and metrics to Dynamo components.",
            ),
            option(
                flag_name="--disaggregation-mode",
                env_var="DYN_TRTLLM_DISAGGREGATION_MODE",
                default=DisaggregationMode.AGGREGATED.value,
                choices=[member.value for member in DisaggregationMode],
                help="Mode to use for disaggregation.",
            ),
            option(
                flag_name="--modality",
                env_var="DYN_TRTLLM_MODALITY",
                default=Modality.TEXT.value,
                choices=[member.value for member in Modality],
                help="Modality to use for the model.",
            ),
            option(
                flag_name="--encode-endpoint",
                env_var="DYN_TRTLLM_ENCODE_ENDPOINT",
                default="",
                help="Endpoint (in 'dyn://namespace.component.endpoint' format) for the encode worker.",
            ),
            option(
                flag_name="--allowed-local-media-path",
                env_var="DYN_TRTLLM_ALLOWED_LOCAL_MEDIA_PATH",
                default="",
                help="Path to a directory that is allowed to be accessed by the model.",
            ),
            option(
                flag_name="--max-file-size-mb",
                env_var="DYN_TRTLLM_MAX_FILE_SIZE_MB",
                default=50,
                arg_type=int,
                help="Maximum size of downloadable embedding files/Image URLs.",
            ),
        ]
        register(backend_group, backend_options)

        # ---- Video diffusion options (experimental) ----
        diffusion_group = parser.add_argument_group(
            "Diffusion Options [Experimental]",
            "Options for video_diffusion modality",
        )
        diffusion_options = [
            option(
                flag_name="--output-dir",
                env_var="DYN_TRTLLM_OUTPUT_DIR",
                default="/tmp/dynamo_videos",
                help="Directory to store generated videos/images.",
            ),
            option(
                flag_name="--default-height",
                env_var="DYN_TRTLLM_DEFAULT_HEIGHT",
                default=480,
                arg_type=int,
                help="Default video/image height in pixels.",
            ),
            option(
                flag_name="--default-width",
                env_var="DYN_TRTLLM_DEFAULT_WIDTH",
                default=832,
                arg_type=int,
                help="Default video/image width in pixels.",
            ),
            option(
                flag_name="--default-num-frames",
                env_var="DYN_TRTLLM_DEFAULT_NUM_FRAMES",
                default=81,
                arg_type=int,
                help="Default number of frames for video generation.",
            ),
            option(
                flag_name="--default-num-inference-steps",
                env_var="DYN_TRTLLM_DEFAULT_NUM_INFERENCE_STEPS",
                default=50,
                arg_type=int,
                help="Default number of inference steps.",
            ),
            option(
                flag_name="--default-guidance-scale",
                env_var="DYN_TRTLLM_DEFAULT_GUIDANCE_SCALE",
                default=5.0,
                arg_type=float,
                help="Default CFG guidance scale.",
            ),
            switch(
                flag_name="--enable-teacache",
                env_var="DYN_TRTLLM_ENABLE_TEACACHE",
                default=False,
                help="Enable TeaCache optimization for faster generation.",
            ),
            option(
                flag_name="--teacache-thresh",
                env_var="DYN_TRTLLM_TEACACHE_THRESH",
                default=0.2,
                arg_type=float,
                help="TeaCache threshold.",
            ),
            option(
                flag_name="--attn-type",
                env_var="DYN_TRTLLM_ATTN_TYPE",
                default="default",
                choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"],
                help="Attention type for diffusion models.",
            ),
            option(
                flag_name="--linear-type",
                env_var="DYN_TRTLLM_LINEAR_TYPE",
                default="default",
                choices=[
                    "default",
                    "trtllm-fp8-blockwise",
                    "trtllm-fp8-per-tensor",
                    "trtllm-nvfp4",
                ],
                help="Linear type for quantization.",
            ),
            switch(
                flag_name="--disable-torch-compile",
                env_var="DYN_TRTLLM_DISABLE_TORCH_COMPILE",
                default=False,
                help="Disable torch.compile optimization.",
            ),
            option(
                flag_name="--torch-compile-mode",
                env_var="DYN_TRTLLM_TORCH_COMPILE_MODE",
                default="default",
                choices=["default", "reduce-overhead", "max-autotune"],
                help="torch.compile mode.",
            ),
            option(
                flag_name="--dit-dp-size",
                env_var="DYN_TRTLLM_DIT_DP_SIZE",
                default=1,
                arg_type=int,
                help="Data parallel size for DiT.",
            ),
            option(
                flag_name="--dit-tp-size",
                env_var="DYN_TRTLLM_DIT_TP_SIZE",
                default=1,
                arg_type=int,
                help="Tensor parallel size for DiT.",
            ),
            option(
                flag_name="--dit-ulysses-size",
                env_var="DYN_TRTLLM_DIT_ULYSSES_SIZE",
                default=1,
                arg_type=int,
                help="Ulysses parallel size for DiT.",
            ),
            option(
                flag_name="--dit-ring-size",
                env_var="DYN_TRTLLM_DIT_RING_SIZE",
                default=1,
                arg_type=int,
                help="Ring parallel size for DiT.",
            ),
            option(
                flag_name="--dit-cfg-size",
                env_var="DYN_TRTLLM_DIT_CFG_SIZE",
                default=1,
                arg_type=int,
                help="CFG parallel size for DiT.",
            ),
            option(
                flag_name="--dit-fsdp-size",
                env_var="DYN_TRTLLM_DIT_FSDP_SIZE",
                default=1,
                arg_type=int,
                help="FSDP size for DiT.",
            ),
            switch(
                flag_name="--enable-async-cpu-offload",
                env_var="DYN_TRTLLM_ENABLE_ASYNC_CPU_OFFLOAD",
                default=False,
                help="Enable async CPU offload for memory efficiency.",
            ),
        ]
        register(diffusion_group, diffusion_options)
class DynamoTrtllmConfig(ConfigBase):
    """Typed configuration for the Dynamo TRT-LLM backend-specific options.

    Each attribute corresponds to one option registered by
    ``DynamoTrtllmArgGroup`` (dashes in the flag name become underscores).
    Call ``validate()`` after population to normalise enum and name fields.
    """

    # ---- LLM engine / worker options ----
    model: str
    served_model_name: Optional[str] = None
    tensor_parallel_size: int
    pipeline_parallel_size: int
    expert_parallel_size: Optional[int]
    enable_attention_dp: bool
    kv_block_size: int
    gpus_per_node: Optional[int] = None
    max_batch_size: int
    max_num_tokens: int
    # Optional: the CLI default is BuildConfig's own default, and the option's
    # help text states the value is deduced from the model config when it is
    # unspecified, so an unset value must be representable here.
    max_seq_len: Optional[int]
    max_beam_width: int
    free_gpu_memory_fraction: float
    extra_engine_args: str
    override_engine_args: str
    publish_events_and_metrics: bool
    # Both enum fields may arrive as plain strings; validate() converts them.
    disaggregation_mode: DisaggregationMode
    modality: Modality
    encode_endpoint: str
    allowed_local_media_path: str
    max_file_size_mb: int

    # ---- Video diffusion options (experimental) ----
    output_dir: str
    default_height: int
    default_width: int
    default_num_frames: int
    default_num_inference_steps: int
    default_guidance_scale: float
    enable_teacache: bool
    teacache_thresh: float
    attn_type: str
    linear_type: str
    disable_torch_compile: bool
    torch_compile_mode: str
    dit_dp_size: int
    dit_tp_size: int
    dit_ulysses_size: int
    dit_ring_size: int
    dit_cfg_size: int
    dit_fsdp_size: int
    enable_async_cpu_offload: bool

    def validate(self) -> None:
        """Normalise fields in place.

        * ``disaggregation_mode`` and ``modality`` given as strings are
          converted to their enum members (an unknown value raises
          ``ValueError`` from the enum constructor).
        * A falsy ``served_model_name`` (e.g. an empty string) becomes ``None``.
        """
        if isinstance(self.disaggregation_mode, str):
            self.disaggregation_mode = DisaggregationMode(self.disaggregation_mode)
        if isinstance(self.modality, str):
            self.modality = Modality(self.modality)
        if not self.served_model_name:
            self.served_model_name = None
...@@ -20,14 +20,14 @@ import uvloop ...@@ -20,14 +20,14 @@ import uvloop
from dynamo.common.utils.runtime import create_runtime from dynamo.common.utils.runtime import create_runtime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.utils.trtllm_utils import cmd_line_args from dynamo.trtllm.args import parse_args
from dynamo.trtllm.workers import init_worker from dynamo.trtllm.workers import init_worker
configure_dynamo_logging() configure_dynamo_logging()
async def worker(): async def worker():
config = cmd_line_args() config = parse_args()
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
runtime, _ = create_runtime( runtime, _ = create_runtime(
......
...@@ -16,8 +16,9 @@ if not torch.cuda.is_available(): ...@@ -16,8 +16,9 @@ if not torch.cuda.is_available():
allow_module_level=True, allow_module_level=True,
) )
from dynamo.trtllm.args import Config, parse_args
from dynamo.trtllm.tests.conftest import make_cli_args_fixture from dynamo.trtllm.tests.conftest import make_cli_args_fixture
from dynamo.trtllm.utils.trtllm_utils import cmd_line_args from dynamo.trtllm.utils.trtllm_utils import deep_update
# Get path relative to this test file # Get path relative to this test file
REPO_ROOT = Path(__file__).resolve().parents[5] REPO_ROOT = Path(__file__).resolve().parents[5]
...@@ -51,13 +52,13 @@ def test_custom_jinja_template_invalid_path(mock_trtllm_cli): ...@@ -51,13 +52,13 @@ def test_custom_jinja_template_invalid_path(mock_trtllm_cli):
FileNotFoundError, FileNotFoundError,
match=re.escape(f"Custom Jinja template file not found: {invalid_path}"), match=re.escape(f"Custom Jinja template file not found: {invalid_path}"),
): ):
cmd_line_args() # This will read in from argv parse_args() # Reads from argv set by fixture
def test_custom_jinja_template_valid_path(mock_trtllm_cli): def test_custom_jinja_template_valid_path(mock_trtllm_cli):
"""Test that valid absolute path is stored correctly.""" """Test that valid absolute path is stored correctly."""
mock_trtllm_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=JINJA_TEMPLATE_PATH) mock_trtllm_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=JINJA_TEMPLATE_PATH)
config = cmd_line_args() config = parse_args()
assert config.custom_jinja_template == JINJA_TEMPLATE_PATH, ( assert config.custom_jinja_template == JINJA_TEMPLATE_PATH, (
f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, " f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, "
...@@ -73,10 +74,93 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_trtllm_cli): ...@@ -73,10 +74,93 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_trtllm_cli):
cli_path = "$JINJA_DIR/custom_template.jinja" cli_path = "$JINJA_DIR/custom_template.jinja"
mock_trtllm_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=cli_path) mock_trtllm_cli(model="Qwen/Qwen3-0.6B", custom_jinja_template=cli_path)
config = cmd_line_args() config = parse_args()
assert "$JINJA_DIR" not in config.custom_jinja_template assert "$JINJA_DIR" not in config.custom_jinja_template
assert config.custom_jinja_template == JINJA_TEMPLATE_PATH, ( assert config.custom_jinja_template == JINJA_TEMPLATE_PATH, (
f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, " f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, "
f"got {config.custom_jinja_template}" f"got {config.custom_jinja_template}"
) )
# ---- Tests for trtllm/args.py (Config, parse_args) ----
def test_parse_args_returns_config_with_expected_attrs(monkeypatch):
"""parse_args returns a Config instance with model, component, and endpoint set."""
monkeypatch.delenv("DYN_NAMESPACE", raising=False)
monkeypatch.delenv("DYN_TRTLLM_MODEL", raising=False)
config = parse_args(["--namespace", "testns", "--model-path", "Qwen/Qwen3-0.6B"])
assert isinstance(config, Config)
assert config.model == "Qwen/Qwen3-0.6B"
assert config.namespace == "testns"
assert config.component == "tensorrt_llm"
assert config.endpoint == "generate"
def test_config_use_kv_events_derived_from_publish_events(monkeypatch):
"""Config.validate sets use_kv_events from publish_events_and_metrics."""
monkeypatch.delenv("DYN_TRTLLM_PUBLISH_EVENTS", raising=False)
config = parse_args(["--publish-events"])
assert config.publish_events_and_metrics is True
assert config.use_kv_events is True
config_off = parse_args(["--no-publish-events"])
assert config_off.publish_events_and_metrics is False
assert config_off.use_kv_events is False
def test_config_has_connector(monkeypatch):
"""Config.has_connector returns True only for the single configured connector."""
monkeypatch.delenv("DYN_CONNECTOR", raising=False)
config_none = parse_args(["--connector", "none"])
assert config_none.has_connector("none") is True
assert config_none.has_connector("kvbm") is False
config_kvbm = parse_args(["--connector", "kvbm"])
assert config_kvbm.has_connector("kvbm") is True
assert config_kvbm.has_connector("none") is False
def test_config_multiple_connectors_fails(monkeypatch):
"""Config.validate fails if multiple connectors are provided."""
monkeypatch.delenv("DYN_CONNECTOR", raising=False)
with pytest.raises(
ValueError,
match="TRT-LLM supports at most one connector entry. Use `--connector none` or `--connector kvbm`.",
):
parse_args(["--connector", "none", "kvbm"])
# ---- Tests for trtllm_utils.deep_update ----
def test_deep_update_nested_merge():
"""deep_update merges nested dicts without removing existing keys."""
target = {"a": 1, "b": {"x": 10, "y": 20}}
source = {"b": {"y": 21, "z": 30}}
deep_update(target, source)
assert target == {"a": 1, "b": {"x": 10, "y": 21, "z": 30}}
def test_deep_update_overwrites_scalar_with_value():
"""deep_update overwrites a key with a non-dict value."""
target = {"a": 1, "b": {"x": 10}}
source = {"a": 2, "b": 99}
deep_update(target, source)
assert target == {"a": 2, "b": 99}
def test_deep_update_empty_source_unchanged():
"""deep_update with empty source leaves target unchanged."""
target = {"a": 1, "b": {"x": 10}}
deep_update(target, {})
assert target == {"a": 1, "b": {"x": 10}}
def test_deep_update_adds_new_keys():
"""deep_update adds new keys from source that are not in target."""
target = {"a": 1}
source = {"b": 2, "c": {"nested": 3}}
deep_update(target, source)
assert target == {"a": 1, "b": 2, "c": {"nested": 3}}
...@@ -20,8 +20,8 @@ import asyncio ...@@ -20,8 +20,8 @@ import asyncio
import logging import logging
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import Modality from dynamo.trtllm.constants import Modality
from dynamo.trtllm.utils.trtllm_utils import Config
from dynamo.trtllm.workers.llm_worker import init_llm_worker from dynamo.trtllm.workers.llm_worker import init_llm_worker
...@@ -40,7 +40,7 @@ async def init_worker( ...@@ -40,7 +40,7 @@ async def init_worker(
""" """
logging.info(f"Initializing worker with modality={config.modality}") logging.info(f"Initializing worker with modality={config.modality}")
modality = Modality(config.modality) modality = config.modality
if Modality.is_diffusion(modality): if Modality.is_diffusion(modality):
if modality == Modality.VIDEO_DIFFUSION: if modality == Modality.VIDEO_DIFFUSION:
......
...@@ -45,6 +45,7 @@ from dynamo.llm import ( ...@@ -45,6 +45,7 @@ from dynamo.llm import (
register_model, register_model,
) )
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
from dynamo.trtllm.constants import DisaggregationMode from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
...@@ -54,7 +55,7 @@ from dynamo.trtllm.request_handlers.handlers import ( ...@@ -54,7 +55,7 @@ from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig, RequestHandlerConfig,
RequestHandlerFactory, RequestHandlerFactory,
) )
from dynamo.trtllm.utils.trtllm_utils import Config, deep_update from dynamo.trtllm.utils.trtllm_utils import deep_update
# Default buffer size for kv cache events. # Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
...@@ -92,17 +93,17 @@ async def get_engine_runtime_config( ...@@ -92,17 +93,17 @@ async def get_engine_runtime_config(
def build_kv_connector_config(config: Config): def build_kv_connector_config(config: Config):
if config.connector is not None: if config.connector:
if config.connector == "kvbm": if config.connector[0] == "kvbm":
return KvCacheConnectorConfig( return KvCacheConnectorConfig(
connector_module="kvbm.trtllm_integration.connector", connector_module="kvbm.trtllm_integration.connector",
connector_scheduler_class="DynamoKVBMConnectorLeader", connector_scheduler_class="DynamoKVBMConnectorLeader",
connector_worker_class="DynamoKVBMConnectorWorker", connector_worker_class="DynamoKVBMConnectorWorker",
) )
elif config.connector == "none": elif config.connector[0] == "none":
return None return None
else: else:
logging.error(f"Invalid connector: {config.connector}") logging.error(f"Invalid connector: {config.connector[0]}")
sys.exit(1) sys.exit(1)
return None return None
...@@ -138,7 +139,7 @@ async def init_llm_worker( ...@@ -138,7 +139,7 @@ async def init_llm_worker(
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
# Convert model path to Path object if it's a local path, otherwise keep as string # Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path) model_path = str(config.model)
if config.gpus_per_node is None: if config.gpus_per_node is None:
gpus_per_node = device_count() gpus_per_node = device_count()
...@@ -151,7 +152,7 @@ async def init_llm_worker( ...@@ -151,7 +152,7 @@ async def init_llm_worker(
free_gpu_memory_fraction=config.free_gpu_memory_fraction free_gpu_memory_fraction=config.free_gpu_memory_fraction
) )
if config.connector is not None and "kvbm" in config.connector: if config.has_connector("kvbm"):
kv_cache_config.enable_partial_reuse = False kv_cache_config.enable_partial_reuse = False
dynamic_batch_config = DynamicBatchConfig( dynamic_batch_config = DynamicBatchConfig(
...@@ -275,15 +276,13 @@ async def init_llm_worker( ...@@ -275,15 +276,13 @@ async def init_llm_worker(
if config.disaggregation_mode == DisaggregationMode.PREFILL: if config.disaggregation_mode == DisaggregationMode.PREFILL:
model_type = ModelType.Prefill model_type = ModelType.Prefill
else: else:
model_type = parse_endpoint_types(config.dyn_endpoint_types) model_type = parse_endpoint_types(config.endpoint_types)
logging.info( logging.info(f"Registering model with endpoint types: {config.endpoint_types}")
f"Registering model with endpoint types: {config.dyn_endpoint_types}"
)
# Warn if custom template provided but chat endpoint not enabled # Warn if custom template provided but chat endpoint not enabled
if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types: if config.custom_jinja_template and "chat" not in config.endpoint_types:
logging.warning( logging.warning(
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. " "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --endpoint-types. "
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available." "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
) )
...@@ -298,12 +297,10 @@ async def init_llm_worker( ...@@ -298,12 +297,10 @@ async def init_llm_worker(
if modality == "multimodal": if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False engine_args["skip_tokenizer_init"] = False
model_config = AutoConfig.from_pretrained( model_config = AutoConfig.from_pretrained(config.model, trust_remote_code=True)
config.model_path, trust_remote_code=True
)
multimodal_processor = MultimodalRequestProcessor( multimodal_processor = MultimodalRequestProcessor(
model_type=model_config.model_type, model_type=model_config.model_type,
model_dir=config.model_path, model_dir=config.model,
max_file_size_mb=config.max_file_size_mb, max_file_size_mb=config.max_file_size_mb,
tokenizer=tokenizer, tokenizer=tokenizer,
allowed_local_media_path=config.allowed_local_media_path, allowed_local_media_path=config.allowed_local_media_path,
...@@ -322,7 +319,7 @@ async def init_llm_worker( ...@@ -322,7 +319,7 @@ async def init_llm_worker(
) )
# Prepare model name for metrics # Prepare model name for metrics
model_name_for_metrics = config.served_model_name or config.model_path model_name_for_metrics = config.served_model_name or config.model
# Construct Prometheus gauges directly; passed through to the engine and publisher # Construct Prometheus gauges directly; passed through to the engine and publisher
# via explicit parameters (no module-level global). # via explicit parameters (no module-level global).
...@@ -357,8 +354,8 @@ async def init_llm_worker( ...@@ -357,8 +354,8 @@ async def init_llm_worker(
# Both parameters control the same thing: how many requests can be processed simultaneously # Both parameters control the same thing: how many requests can be processed simultaneously
runtime_config.max_num_seqs = config.max_batch_size runtime_config.max_num_seqs = config.max_batch_size
runtime_config.max_num_batched_tokens = config.max_num_tokens runtime_config.max_num_batched_tokens = config.max_num_tokens
runtime_config.reasoning_parser = config.reasoning_parser runtime_config.reasoning_parser = config.dyn_reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser runtime_config.tool_call_parser = config.dyn_tool_call_parser
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
runtime_config.enable_local_indexer = ( runtime_config.enable_local_indexer = (
config.enable_local_indexer config.enable_local_indexer
...@@ -386,7 +383,7 @@ async def init_llm_worker( ...@@ -386,7 +383,7 @@ async def init_llm_worker(
metrics_collector = None metrics_collector = None
if config.publish_events_and_metrics: if config.publish_events_and_metrics:
try: try:
model_name_for_metrics = config.served_model_name or config.model_path model_name_for_metrics = config.served_model_name or config.model
metrics_collector = MetricsCollector( metrics_collector = MetricsCollector(
{"model_name": model_name_for_metrics, "engine_type": "trtllm"} {"model_name": model_name_for_metrics, "engine_type": "trtllm"}
) )
...@@ -430,7 +427,7 @@ async def init_llm_worker( ...@@ -430,7 +427,7 @@ async def init_llm_worker(
metrics_collector=metrics_collector, metrics_collector=metrics_collector,
kv_block_size=config.kv_block_size, kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
encoder_cache_capacity_gb=config.encoder_cache_capacity_gb, encoder_cache_capacity_gb=config.multimodal_embedding_cache_capacity_gb,
) )
# Register the model with runtime config # Register the model with runtime config
...@@ -441,7 +438,7 @@ async def init_llm_worker( ...@@ -441,7 +438,7 @@ async def init_llm_worker(
model_input, model_input,
model_type, model_type,
endpoint, endpoint,
config.model_path, config.model,
config.served_model_name, config.served_model_name,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
runtime_config=runtime_config, runtime_config=runtime_config,
...@@ -457,8 +454,8 @@ async def init_llm_worker( ...@@ -457,8 +454,8 @@ async def init_llm_worker(
kv_listener = runtime.namespace(config.namespace).component( kv_listener = runtime.namespace(config.namespace).component(
config.component config.component
) )
# Use model_path as fallback if served_model_name is not provided # Use model as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model_path model_name_for_metrics = config.served_model_name or config.model
metrics_labels = [ metrics_labels = [
( (
prometheus_names.labels.MODEL, prometheus_names.labels.MODEL,
......
...@@ -12,7 +12,7 @@ import logging ...@@ -12,7 +12,7 @@ import logging
from dynamo.llm import ModelInput, ModelType, register_model from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.utils.trtllm_utils import Config from dynamo.trtllm.args import Config
async def init_video_diffusion_worker( async def init_video_diffusion_worker(
...@@ -58,7 +58,7 @@ async def init_video_diffusion_worker( ...@@ -58,7 +58,7 @@ async def init_video_diffusion_worker(
discovery_backend=config.discovery_backend, discovery_backend=config.discovery_backend,
request_plane=config.request_plane, request_plane=config.request_plane,
event_plane=config.event_plane, event_plane=config.event_plane,
model_path=config.model_path, model_path=config.model,
served_model_name=config.served_model_name, served_model_name=config.served_model_name,
output_dir=config.output_dir, output_dir=config.output_dir,
default_height=config.default_height, default_height=config.default_height,
...@@ -93,7 +93,7 @@ async def init_video_diffusion_worker( ...@@ -93,7 +93,7 @@ async def init_video_diffusion_worker(
handler = VideoGenerationHandler(component, engine, diffusion_config) handler = VideoGenerationHandler(component, engine, diffusion_config)
# Register the model with Dynamo's discovery system # Register the model with Dynamo's discovery system
model_name = config.served_model_name or config.model_path model_name = config.served_model_name or config.model
# Use ModelType.Videos for video generation # Use ModelType.Videos for video generation
if not hasattr(ModelType, "Videos"): if not hasattr(ModelType, "Videos"):
...@@ -111,7 +111,7 @@ async def init_video_diffusion_worker( ...@@ -111,7 +111,7 @@ async def init_video_diffusion_worker(
ModelInput.Text, ModelInput.Text,
model_type, model_type,
endpoint, endpoint,
config.model_path, config.model,
model_name, model_name,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment