Unverified commit a72bd22d, authored by Hongkuan Zhou, committed by GitHub
Browse files

refactor: switch trtllm config modifier to --trtllm.* dynamic CLI flags (#7884)

parent ddaf0fb5
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import re
from typing import Tuple
......@@ -15,7 +14,6 @@ from dynamo.profiler.utils.config import (
break_arguments,
get_service_name_by_type,
get_worker_service_from_config,
parse_override_engine_args,
remove_valued_arguments,
setup_worker_service_resources,
update_image,
......@@ -46,6 +44,20 @@ DEFAULT_TRTLLM_AGG_CONFIG_PATH = resolve_deploy_path(
)
def _trtllm_flags(overrides: dict) -> list[str]:
"""Build a flat ``--trtllm.<dotted.key> <value>`` flag list."""
flags: list[str] = []
for key, value in overrides.items():
flags.append(f"--trtllm.{key}")
if value is None:
flags.append("none")
elif isinstance(value, bool):
flags.append(str(value).lower())
else:
flags.append(str(value))
return flags
class TrtllmConfigModifier(BaseConfigModifier):
BACKEND = "trtllm"
......@@ -116,28 +128,16 @@ class TrtllmConfigModifier(BaseConfigModifier):
args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-strategy")
# Keep the original extra-engine-args (prefill.yaml) which may contain user settings
# Check if user already has override-engine-args and merge with our changes
override_dict, args = parse_override_engine_args(args)
# Merge our overrides for converting prefill-only disagg to aggregated:
# - Disable enable_block_reuse (no KV reuse for prefill-only)
# - Enable overlap scheduler (disabled in prefill.yaml but needed for agg)
# - Remove cache_transceiver_config (not needed in agg mode)
if "kv_cache_config" not in override_dict or not isinstance(
override_dict["kv_cache_config"], dict
):
override_dict["kv_cache_config"] = {}
override_dict["kv_cache_config"]["enable_block_reuse"] = False
override_dict[
"disable_overlap_scheduler"
] = False # Enable overlap scheduler for agg
override_dict[
"cache_transceiver_config"
] = None # Remove cache transceiver for agg
override_str = json.dumps(override_dict)
args = append_argument(args, ["--override-engine-args", override_str])
args = append_argument(
args,
_trtllm_flags(
{
"kv_cache_config.enable_block_reuse": False,
"disable_overlap_scheduler": False,
"cache_transceiver_config": None,
}
),
)
worker_service.extraPodSpec.mainContainer.args = args
......@@ -170,24 +170,15 @@ class TrtllmConfigModifier(BaseConfigModifier):
args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-strategy")
# Keep the original extra-engine-args (decode.yaml) which may contain user settings
# Check if user already has override-engine-args and merge with our changes
override_dict, args = parse_override_engine_args(args)
# Merge our overrides for converting decode-only disagg to aggregated:
# - Enable enable_block_reuse (to skip prefill in decode-only)
# - Remove cache_transceiver_config (not needed in agg mode)
if "kv_cache_config" not in override_dict or not isinstance(
override_dict["kv_cache_config"], dict
):
override_dict["kv_cache_config"] = {}
override_dict["kv_cache_config"]["enable_block_reuse"] = True
override_dict[
"cache_transceiver_config"
] = None # Remove cache transceiver for agg
override_str = json.dumps(override_dict)
args = append_argument(args, ["--override-engine-args", override_str])
args = append_argument(
args,
_trtllm_flags(
{
"kv_cache_config.enable_block_reuse": True,
"cache_transceiver_config": None,
}
),
)
worker_service.extraPodSpec.mainContainer.args = args
......@@ -225,14 +216,7 @@ class TrtllmConfigModifier(BaseConfigModifier):
# Break arguments to handle both joined strings and lists
args = break_arguments(args)
# For TRT-LLM, we need to update the override-engine-args
# to set the tensor_parallel_size
override_dict, args = parse_override_engine_args(args)
# Add/update tensor_parallel_size in the override
override_dict["tensor_parallel_size"] = tp_size
override_str = json.dumps(override_dict)
args = append_argument(args, ["--override-engine-args", override_str])
args = append_argument(args, _trtllm_flags({"tensor_parallel_size": tp_size}))
worker_service.extraPodSpec.mainContainer.args = args
......@@ -328,7 +312,7 @@ class TrtllmConfigModifier(BaseConfigModifier):
) -> dict:
"""
Configure prefill-related limits for aggregated prefill runs.
For TRT-LLM we set these via --override-engine-args JSON:
For TRT-LLM we set these via ``--trtllm.*`` dynamic CLI flags:
- max_batch_size
- max_num_tokens
"""
......@@ -339,12 +323,15 @@ class TrtllmConfigModifier(BaseConfigModifier):
args = validate_and_get_worker_args(worker_service, backend="trtllm")
args = break_arguments(args)
# Parse existing override-engine-args (if any) and update
override_dict, args = parse_override_engine_args(args)
override_dict["max_batch_size"] = int(max_batch_size)
override_dict["max_num_tokens"] = int(max_num_tokens)
override_str = json.dumps(override_dict)
args = append_argument(args, ["--override-engine-args", override_str])
args = append_argument(
args,
_trtllm_flags(
{
"max_batch_size": int(max_batch_size),
"max_num_tokens": int(max_num_tokens),
}
),
)
worker_service.extraPodSpec.mainContainer.args = args
return cfg.model_dump()
......@@ -13,8 +13,11 @@ DYNAMIC_FLAG_PREFIX = "--trtllm."
def infer_type(value: str) -> Any:
"""Infer the Python type of a CLI value string.
Tries int, float, bool, then falls back to string.
Tries None, int, float, bool, then falls back to string.
"""
# none / null
if value.lower() in ("none", "null"):
return None
# int
try:
return int(value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment