Unverified Commit a72bd22d authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor: switch trtllm config modifier to --trtllm.* dynamic CLI flags (#7884)

parent ddaf0fb5
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json
import logging import logging
import re import re
from typing import Tuple from typing import Tuple
...@@ -15,7 +14,6 @@ from dynamo.profiler.utils.config import ( ...@@ -15,7 +14,6 @@ from dynamo.profiler.utils.config import (
break_arguments, break_arguments,
get_service_name_by_type, get_service_name_by_type,
get_worker_service_from_config, get_worker_service_from_config,
parse_override_engine_args,
remove_valued_arguments, remove_valued_arguments,
setup_worker_service_resources, setup_worker_service_resources,
update_image, update_image,
...@@ -46,6 +44,20 @@ DEFAULT_TRTLLM_AGG_CONFIG_PATH = resolve_deploy_path( ...@@ -46,6 +44,20 @@ DEFAULT_TRTLLM_AGG_CONFIG_PATH = resolve_deploy_path(
) )
def _trtllm_flags(overrides: dict) -> list[str]:
"""Build a flat ``--trtllm.<dotted.key> <value>`` flag list."""
flags: list[str] = []
for key, value in overrides.items():
flags.append(f"--trtllm.{key}")
if value is None:
flags.append("none")
elif isinstance(value, bool):
flags.append(str(value).lower())
else:
flags.append(str(value))
return flags
class TrtllmConfigModifier(BaseConfigModifier): class TrtllmConfigModifier(BaseConfigModifier):
BACKEND = "trtllm" BACKEND = "trtllm"
...@@ -116,28 +128,16 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -116,28 +128,16 @@ class TrtllmConfigModifier(BaseConfigModifier):
args = remove_valued_arguments(args, "--disaggregation-mode") args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-strategy") args = remove_valued_arguments(args, "--disaggregation-strategy")
# Keep the original extra-engine-args (prefill.yaml) which may contain user settings args = append_argument(
# Check if user already has override-engine-args and merge with our changes args,
override_dict, args = parse_override_engine_args(args) _trtllm_flags(
{
# Merge our overrides for converting prefill-only disagg to aggregated: "kv_cache_config.enable_block_reuse": False,
# - Disable enable_block_reuse (no KV reuse for prefill-only) "disable_overlap_scheduler": False,
# - Enable overlap scheduler (disabled in prefill.yaml but needed for agg) "cache_transceiver_config": None,
# - Remove cache_transceiver_config (not needed in agg mode) }
if "kv_cache_config" not in override_dict or not isinstance( ),
override_dict["kv_cache_config"], dict )
):
override_dict["kv_cache_config"] = {}
override_dict["kv_cache_config"]["enable_block_reuse"] = False
override_dict[
"disable_overlap_scheduler"
] = False # Enable overlap scheduler for agg
override_dict[
"cache_transceiver_config"
] = None # Remove cache transceiver for agg
override_str = json.dumps(override_dict)
args = append_argument(args, ["--override-engine-args", override_str])
worker_service.extraPodSpec.mainContainer.args = args worker_service.extraPodSpec.mainContainer.args = args
...@@ -170,24 +170,15 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -170,24 +170,15 @@ class TrtllmConfigModifier(BaseConfigModifier):
args = remove_valued_arguments(args, "--disaggregation-mode") args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-strategy") args = remove_valued_arguments(args, "--disaggregation-strategy")
# Keep the original extra-engine-args (decode.yaml) which may contain user settings args = append_argument(
# Check if user already has override-engine-args and merge with our changes args,
override_dict, args = parse_override_engine_args(args) _trtllm_flags(
{
# Merge our overrides for converting decode-only disagg to aggregated: "kv_cache_config.enable_block_reuse": True,
# - Enable enable_block_reuse (to skip prefill in decode-only) "cache_transceiver_config": None,
# - Remove cache_transceiver_config (not needed in agg mode) }
if "kv_cache_config" not in override_dict or not isinstance( ),
override_dict["kv_cache_config"], dict )
):
override_dict["kv_cache_config"] = {}
override_dict["kv_cache_config"]["enable_block_reuse"] = True
override_dict[
"cache_transceiver_config"
] = None # Remove cache transceiver for agg
override_str = json.dumps(override_dict)
args = append_argument(args, ["--override-engine-args", override_str])
worker_service.extraPodSpec.mainContainer.args = args worker_service.extraPodSpec.mainContainer.args = args
...@@ -225,14 +216,7 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -225,14 +216,7 @@ class TrtllmConfigModifier(BaseConfigModifier):
# Break arguments to handle both joined strings and lists # Break arguments to handle both joined strings and lists
args = break_arguments(args) args = break_arguments(args)
# For TRT-LLM, we need to update the override-engine-args args = append_argument(args, _trtllm_flags({"tensor_parallel_size": tp_size}))
# to set the tensor_parallel_size
override_dict, args = parse_override_engine_args(args)
# Add/update tensor_parallel_size in the override
override_dict["tensor_parallel_size"] = tp_size
override_str = json.dumps(override_dict)
args = append_argument(args, ["--override-engine-args", override_str])
worker_service.extraPodSpec.mainContainer.args = args worker_service.extraPodSpec.mainContainer.args = args
...@@ -328,7 +312,7 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -328,7 +312,7 @@ class TrtllmConfigModifier(BaseConfigModifier):
) -> dict: ) -> dict:
""" """
Configure prefill-related limits for aggregated prefill runs. Configure prefill-related limits for aggregated prefill runs.
For TRT-LLM we set these via --override-engine-args JSON: For TRT-LLM we set these via ``--trtllm.*`` dynamic CLI flags:
- max_batch_size - max_batch_size
- max_num_tokens - max_num_tokens
""" """
...@@ -339,12 +323,15 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -339,12 +323,15 @@ class TrtllmConfigModifier(BaseConfigModifier):
args = validate_and_get_worker_args(worker_service, backend="trtllm") args = validate_and_get_worker_args(worker_service, backend="trtllm")
args = break_arguments(args) args = break_arguments(args)
# Parse existing override-engine-args (if any) and update args = append_argument(
override_dict, args = parse_override_engine_args(args) args,
override_dict["max_batch_size"] = int(max_batch_size) _trtllm_flags(
override_dict["max_num_tokens"] = int(max_num_tokens) {
override_str = json.dumps(override_dict) "max_batch_size": int(max_batch_size),
args = append_argument(args, ["--override-engine-args", override_str]) "max_num_tokens": int(max_num_tokens),
}
),
)
worker_service.extraPodSpec.mainContainer.args = args worker_service.extraPodSpec.mainContainer.args = args
return cfg.model_dump() return cfg.model_dump()
...@@ -13,8 +13,11 @@ DYNAMIC_FLAG_PREFIX = "--trtllm." ...@@ -13,8 +13,11 @@ DYNAMIC_FLAG_PREFIX = "--trtllm."
def infer_type(value: str) -> Any: def infer_type(value: str) -> Any:
"""Infer the Python type of a CLI value string. """Infer the Python type of a CLI value string.
Tries int, float, bool, then falls back to string. Tries None, int, float, bool, then falls back to string.
""" """
# none / null
if value.lower() in ("none", "null"):
return None
# int # int
try: try:
return int(value) return int(value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment