Unverified Commit 3882cba4 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: dynamic CLI flags for trtllm (#7335)

parent a43a6602
......@@ -4,6 +4,7 @@
"""Argument parsing and typed config for Dynamo TRT-LLM."""
import argparse
import json
import logging
import os
import sys
......@@ -17,6 +18,7 @@ from dynamo.common.configuration.groups.runtime_args import (
from dynamo.common.utils.runtime import parse_endpoint
from dynamo.trtllm.backend_args import DynamoTrtllmArgGroup, DynamoTrtllmConfig
from dynamo.trtllm.constants import DisaggregationMode, Modality
from dynamo.trtllm.dynamic_flags import parse_dynamic_flags
DEFAULT_ENDPOINT_COMPONENT = "tensorrt_llm"
DEFAULT_PREFILL_COMPONENT = "prefill"
......@@ -70,19 +72,44 @@ def _preprocess_for_encode_config(config: Config) -> Dict[str, Any]:
def parse_args(argv: Optional[Sequence[str]] = None) -> Config:
"""Parse command-line arguments for the TensorRT-LLM backend."""
"""Parse command-line arguments for the TensorRT-LLM backend.
In addition to the known flags, supports dynamic configuration flags
of the form ``--trtllm.<group>.<subgroup>.<key> <value>`` which are
collected into a nested dict and passed through ``override_engine_args``.
Cannot be combined with the explicit ``--override-engine-args`` flag.
"""
cli_args = list(argv) if argv is not None else sys.argv[1:]
parser = argparse.ArgumentParser(
description="Dynamo TensorRT-LLM worker configuration",
description="Dynamo TensorRT-LLM worker configuration\n\n"
"Dynamic engine configuration can be passed via dotted flags:\n"
" --trtllm.<group>.<key> <value>\n"
"Example:\n"
" --trtllm.kv_cache_config.free_gpu_memory_fraction 0.7\n"
"These flags are mutually exclusive with --override-engine-args.",
formatter_class=argparse.RawTextHelpFormatter,
)
DynamoRuntimeArgGroup().add_arguments(parser)
DynamoTrtllmArgGroup().add_arguments(parser)
parsed_args = parser.parse_args(cli_args)
parsed_args, remaining = parser.parse_known_args(cli_args)
config = Config.from_cli_args(parsed_args)
# Parse dynamic --trtllm.* flags from the remaining args
dynamic_overrides = parse_dynamic_flags(remaining)
if dynamic_overrides and config.override_engine_args:
logging.error(
"--override-engine-args and --trtllm.* dynamic flags are mutually "
"exclusive. Use one or the other."
)
sys.exit(1)
if dynamic_overrides:
config.override_engine_args = json.dumps(dynamic_overrides)
config.validate()
# TODO: move this to common configuration.
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utilities for parsing dynamic ``--trtllm.*`` CLI flags into nested dicts."""
import logging
import sys
from typing import Any, Dict, List
# Prefix that marks a CLI argument as a dynamic flag; the text after it is
# interpreted as a dot-separated key path.
DYNAMIC_FLAG_PREFIX = "--trtllm."
def infer_type(value: str) -> Any:
    """Convert a raw CLI value string into the most specific Python value.

    Conversions are attempted in order: ``int``, then ``float``, then a
    case-insensitive match against ``"true"`` / ``"false"``. When none of
    them apply, the input string is returned unchanged.
    """
    # Numeric conversions first; int must precede float so "42" stays an int.
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue

    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    # Anything else is passed through as a plain string.
    return value
def set_nested(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    """Set ``value`` at the nested path ``keys`` inside ``d``, in place.

    Intermediate dicts are created as needed. An existing leaf value may be
    overwritten by another leaf value, but a path may not be used both as a
    leaf and as a container.

    Args:
        d: Dict to mutate.
        keys: Non-empty path of keys; the last one is the leaf key.
        value: Value to store at the leaf.

    Raises:
        ValueError: If ``keys`` is empty, if an intermediate key already
            holds a non-dict value, or if the leaf key already holds a dict
            and ``value`` is not a dict (which would silently discard the
            previously set nested values).
    """
    if not keys:
        raise ValueError("Dynamic flag path must contain at least one key")

    current: Dict[str, Any] = d
    for key in keys[:-1]:
        existing = current.get(key)
        if existing is None:
            # Missing (or None) entry: start a fresh sub-dict at this level.
            existing = current[key] = {}
        elif not isinstance(existing, dict):
            raise ValueError(
                f"Conflicting dynamic flag path: key '{key}' is already set "
                f"to a {type(existing).__name__} value"
            )
        current = existing

    leaf = keys[-1]
    # Mirror the check above: replacing a populated sub-dict with a scalar
    # would drop every value set beneath it without any warning.
    if isinstance(current.get(leaf), dict) and not isinstance(value, dict):
        raise ValueError(
            f"Conflicting dynamic flag path: key '{leaf}' already holds "
            f"nested values and cannot be set to a {type(value).__name__} value"
        )
    current[leaf] = value
def parse_dynamic_flags(remaining: List[str]) -> dict:
    """Parse dynamic ``--trtllm.*`` flags into a nested dict.

    Two spellings are accepted for each flag:

    * ``--trtllm.a.b.c value`` (value in the following argument)
    * ``--trtllm.a.b.c=value`` (value inline after the first ``=``)

    The dotted path after the prefix becomes the nesting of the result and
    each value is converted with ``infer_type``.

    Args:
        remaining: Arguments to parse; every flag must start with
            ``DYNAMIC_FLAG_PREFIX``.

    Returns:
        The nested dict built from all ``--trtllm.*`` flags (empty when
        ``remaining`` is empty).

    Raises:
        SystemExit: If an argument does not start with the dynamic prefix,
            a key segment is empty, a flag has no value, or two flags use
            conflicting paths. The reason is logged before exiting.
    """
    result: Dict[str, Any] = {}
    i = 0
    while i < len(remaining):
        arg = remaining[i]
        if not arg.startswith(DYNAMIC_FLAG_PREFIX):
            logging.error("Unrecognized argument: %s", arg)
            sys.exit(1)

        # Split off an inline "=value" before splitting the key path, so
        # dots inside the value (e.g. "0.7") are never treated as key
        # separators and the next argument is not consumed by mistake.
        flag, has_inline_value, inline_value = arg.partition("=")
        keys = flag[len(DYNAMIC_FLAG_PREFIX) :].split(".")
        if not all(keys):
            logging.error("Invalid dynamic flag (empty key segment): %s", arg)
            sys.exit(1)
        i += 1

        if has_inline_value:
            raw_value = inline_value
        else:
            # Space-separated form: the value is the next argument, which
            # must exist and must not itself look like a flag.
            if i >= len(remaining) or remaining[i].startswith("--"):
                logging.error("Dynamic flag %s requires a value", arg)
                sys.exit(1)
            raw_value = remaining[i]
            i += 1

        value = infer_type(raw_value)
        try:
            set_nested(result, keys, value)
        except ValueError as e:
            logging.error("%s", e)
            sys.exit(1)
    return result
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for dynamic --trtllm.* flag parsing."""
import pytest
from dynamo.trtllm.dynamic_flags import infer_type, parse_dynamic_flags, set_nested
class TestInferType:
    """Checks that ``infer_type`` yields int, float, bool or str as appropriate."""

    def test_int(self):
        parsed = infer_type("42")
        assert isinstance(parsed, int)
        assert parsed == 42

    def test_negative_int(self):
        parsed = infer_type("-1")
        assert isinstance(parsed, int)
        assert parsed == -1

    def test_zero(self):
        parsed = infer_type("0")
        assert isinstance(parsed, int)
        assert parsed == 0

    def test_float(self):
        parsed = infer_type("0.9")
        assert isinstance(parsed, float)
        assert parsed == 0.9

    def test_negative_float(self):
        parsed = infer_type("-0.5")
        assert isinstance(parsed, float)
        assert parsed == -0.5

    def test_bool_true(self):
        # Matching is case-insensitive.
        for text in ("true", "True", "TRUE"):
            assert infer_type(text) is True

    def test_bool_false(self):
        # Matching is case-insensitive.
        for text in ("false", "False", "FALSE"):
            assert infer_type(text) is False

    def test_string(self):
        for text in ("hello", "/path/to/model"):
            assert infer_type(text) == text

    def test_string_not_bool(self):
        # Only "true"/"false" are booleans; "yes"/"no" stay strings.
        for text in ("yes", "no"):
            assert infer_type(text) == text

    def test_scientific_notation(self):
        parsed = infer_type("1e3")
        assert isinstance(parsed, float)
        assert parsed == 1000.0
class TestSetNested:
    """Checks that ``set_nested`` builds and updates nested dicts in place."""

    def test_single_key(self):
        target: dict[str, object] = {}
        set_nested(target, ["key"], "value")
        expected = {"key": "value"}
        assert target == expected

    def test_two_levels(self):
        target: dict[str, object] = {}
        set_nested(target, ["a", "b"], 42)
        expected = {"a": {"b": 42}}
        assert target == expected

    def test_three_levels(self):
        target: dict[str, object] = {}
        set_nested(target, ["a", "b", "c"], True)
        expected = {"a": {"b": {"c": True}}}
        assert target == expected

    def test_preserves_existing(self):
        # A sibling key under the same parent must survive the update.
        target = {"a": {"x": 1}}
        set_nested(target, ["a", "y"], 2)
        expected = {"a": {"x": 1, "y": 2}}
        assert target == expected

    def test_overwrites_leaf(self):
        # Setting an existing leaf replaces its value.
        target = {"a": {"b": "old"}}
        set_nested(target, ["a", "b"], "new")
        expected = {"a": {"b": "new"}}
        assert target == expected
class TestParseDynamicFlags:
    """Checks for ``parse_dynamic_flags``: nesting, typing and error exits."""

    def test_empty(self):
        parsed = parse_dynamic_flags([])
        assert parsed == {}

    def test_single_flat(self):
        argv = ["--trtllm.max_batch_size", "32"]
        assert parse_dynamic_flags(argv) == {"max_batch_size": 32}

    def test_nested(self):
        argv = ["--trtllm.kv_cache_config.free_gpu_memory_fraction", "0.7"]
        expected = {"kv_cache_config": {"free_gpu_memory_fraction": 0.7}}
        assert parse_dynamic_flags(argv) == expected

    def test_deeply_nested(self):
        argv = ["--trtllm.a.b.c.d", "hello"]
        expected = {"a": {"b": {"c": {"d": "hello"}}}}
        assert parse_dynamic_flags(argv) == expected

    def test_multiple_flags(self):
        argv = [
            "--trtllm.kv_cache_config.free_gpu_memory_fraction",
            "0.7",
            "--trtllm.kv_cache_config.enable_block_reuse",
            "false",
            "--trtllm.tensor_parallel_size",
            "4",
        ]
        expected = {
            "kv_cache_config": {
                "free_gpu_memory_fraction": 0.7,
                "enable_block_reuse": False,
            },
            "tensor_parallel_size": 4,
        }
        assert parse_dynamic_flags(argv) == expected

    def test_bool_values(self):
        argv = ["--trtllm.some_flag", "true", "--trtllm.other_flag", "false"]
        expected = {"some_flag": True, "other_flag": False}
        assert parse_dynamic_flags(argv) == expected

    def test_string_values(self):
        argv = ["--trtllm.model", "/path/to/model"]
        assert parse_dynamic_flags(argv) == {"model": "/path/to/model"}

    def test_unrecognized_arg_exits(self):
        argv = ["--unknown-flag", "value"]
        with pytest.raises(SystemExit):
            parse_dynamic_flags(argv)

    def test_missing_value_exits(self):
        argv = ["--trtllm.some_key"]
        with pytest.raises(SystemExit):
            parse_dynamic_flags(argv)

    def test_missing_value_next_is_flag_exits(self):
        # The token after key1 is another flag, so key1 has no value.
        argv = ["--trtllm.key1", "--trtllm.key2", "val"]
        with pytest.raises(SystemExit):
            parse_dynamic_flags(argv)

    def test_empty_key_segment_exits(self):
        argv = ["--trtllm..bad", "value"]
        with pytest.raises(SystemExit):
            parse_dynamic_flags(argv)

    def test_keys_preserved_as_is(self):
        """Key segments pass through untransformed: hyphens, underscores and case are kept."""
        argv = ["--trtllm.My-Key_name.SubKey", "42"]
        expected = {"My-Key_name": {"SubKey": 42}}
        assert parse_dynamic_flags(argv) == expected

    def test_conflicting_path_exits(self):
        """A scalar followed by a nested path under the same key exits cleanly."""
        argv = ["--trtllm.a", "1", "--trtllm.a.b", "2"]
        with pytest.raises(SystemExit):
            parse_dynamic_flags(argv)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment