Unverified Commit dcee4dbd authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: support trtllm in sla-planner (#2980)


Signed-off-by: default avatarBiswa Panda <biswa.panda@gmail.com>
Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Signed-off-by: default avatarayushag <ayushag@nvidia.com>
Signed-off-by: default avatarDillon Cullinan <dcullinan@nvidia.com>
Signed-off-by: default avatarHarrison Saturley-Hall <hsaturleyhal@nvidia.com>
Signed-off-by: default avataralec-flowers <aflowers@nvidia.com>
Signed-off-by: default avatarJulien Mancuso <jmancuso@nvidia.com>
Signed-off-by: default avatarPavithra Vijayakrishnan <160681768+pvijayakrish@users.noreply.github.com>
Signed-off-by: default avatarHarry Kim <harry_kim@live.com>
Signed-off-by: default avatarTushar Sharma <tusharma@nvidia.com>
Signed-off-by: default avatarGuanLuo <gluo@nvidia.com>
Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
Signed-off-by: default avatarAnant Sharma <anants@nvidia.com>
Signed-off-by: default avatarGreg Clark <grclark@nvidia.com>
Signed-off-by: default avatarNeal Vaidya <nealv@nvidia.com>
Co-authored-by: default avatarBiswa Panda <biswa.panda@gmail.com>
Co-authored-by: default avatarAyush Agarwal <ayushag@nvidia.com>
Co-authored-by: default avatarDillon Cullinan <dcullinan92@gmail.com>
Co-authored-by: default avatarHarrison Saturley-Hall <hsaturleyhal@nvidia.com>
Co-authored-by: default avataralec-flowers <aflowers@nvidia.com>
Co-authored-by: default avatarjulienmancuso <161955438+julienmancuso@users.noreply.github.com>
Co-authored-by: default avatarPavithra Vijayakrishnan <160681768+pvijayakrish@users.noreply.github.com>
Co-authored-by: default avatarHarry Kim <harry_kim@live.com>
Co-authored-by: default avatarTushar Sharma <tusharma@nvidia.com>
Co-authored-by: default avatarAlec <35311602+alec-flowers@users.noreply.github.com>
Co-authored-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: default avatarhhzhang16 <54051230+hhzhang16@users.noreply.github.com>
Co-authored-by: default avatarGraham King <grahamk@nvidia.com>
Co-authored-by: default avatarAnant Sharma <anants@nvidia.com>
Co-authored-by: default avatarGreg Clark <grclark@nvidia.com>
Co-authored-by: default avatarNeal Vaidya <nealv@nvidia.com>
Co-authored-by: default avatarnv-nmailhot <nmailhot@nvidia.com>
parent 40000976
......@@ -59,7 +59,7 @@ Dynamo is designed to be inference engine agnostic (supports TRT-LLM, vLLM, SGLa
| [**Conditional Disaggregation**](/docs/architecture/disagg_serving.md#conditional-disaggregation) | 🚧 | 🚧 | 🚧 |
| [**KV-Aware Routing**](/docs/architecture/kv_cache_routing.md) | ✅ | ✅ | ✅ |
| [**Load Based Planner**](/docs/architecture/load_planner.md) | 🚧 | 🚧 | 🚧 |
| [**SLA-Based Planner**](/docs/architecture/sla_planner.md) | ✅ | ✅ | 🚧 |
| [**SLA-Based Planner**](/docs/architecture/sla_planner.md) | ✅ | ✅ | ✅ |
| [**KVBM**](/docs/architecture/kvbm_architecture.md) | ✅ | 🚧 | ✅ |
To learn more about each framework and their capabilities, check out each framework's README!
......
......@@ -548,8 +548,8 @@ if __name__ == "__main__":
"--backend",
type=str,
default="vllm",
choices=["vllm", "sglang"],
help="backend type, currently support [vllm, sglang]",
choices=["vllm", "sglang", "trtllm"],
help="backend type, currently support [vllm, sglang, trtllm]",
)
parser.add_argument(
"--config",
......
......@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
import shlex
from typing import Literal, Optional, Protocol
from pydantic import BaseModel
......@@ -82,11 +84,13 @@ def break_arguments(args: list[str] | None) -> list[str]:
if args is None:
return ans
if isinstance(args, str):
ans = re.split(r"[ =]", args)
# Use shlex.split to properly handle quoted arguments and JSON values
ans = shlex.split(args)
else:
for arg in args:
if arg is not None:
ans.extend(arg.split(" "))
# Use shlex.split to properly handle quoted arguments
ans.extend(shlex.split(arg))
return ans
......@@ -101,7 +105,8 @@ def remove_valued_arguments(args: list[str], key: str) -> list[str]:
def join_arguments(args: list[str]) -> list[str]:
    """Collapse an argument list into a single shell-quoted command string.

    The result is the inverse of splitting with ``shlex.split``: arguments that
    contain spaces, quotes or other shell-special characters (for example a
    JSON value) are quoted so they survive a later split unchanged.

    Args:
        args: Individual command-line tokens.

    Returns:
        A one-element list holding the quoted, space-joined command line.
    """
    # A plain " ".join(args) would lose the token boundaries of any argument
    # containing whitespace, so quoting is required here.
    return [shlex.join(args)]
def append_argument(args: list[str], to_append) -> list[str]:
......@@ -132,6 +137,43 @@ def find_arg_index(args: list[str]) -> int:
return idx
def parse_override_engine_args(args: list[str]) -> tuple[dict, list[str]]:
    """Extract ``--override-engine-args`` from an argument list.

    Every occurrence of the flag (and its value, when present) is removed from
    ``args`` so that a caller can append a single, merged override afterwards
    without producing duplicate flags. When the flag occurs more than once, the
    last value that parses to a JSON object wins.

    Args:
        args: Already-split command-line tokens. The list is modified in place.

    Returns:
        tuple: ``(override_dict, modified_args)`` where ``override_dict`` is the
        parsed JSON object (empty if the flag is absent, has no value, or the
        value is not a valid JSON object) and ``modified_args`` is the same
        list object with the flag removed.
    """
    flag = "--override-engine-args"
    log = logging.getLogger(__name__)
    override_dict: dict = {}

    while flag in args:
        idx = args.index(flag)
        value_idx = idx + 1
        # A following token that is itself an option means the flag has no value;
        # in that case only the flag is dropped so the next option is preserved.
        has_value = value_idx < len(args) and not args[value_idx].startswith("--")
        raw_value = args[value_idx] if has_value else None
        del args[idx : idx + (2 if has_value else 1)]

        if raw_value is None:
            log.warning("Ignoring %s without a value", flag)
            continue
        try:
            parsed = json.loads(raw_value)
        except json.JSONDecodeError as err:
            log.warning("Ignoring %s with invalid JSON: %s", flag, err)
            continue
        if not isinstance(parsed, dict):
            # Callers index into the result by key, so anything but an object is unusable.
            log.warning(
                "Ignoring %s: expected a JSON object, got %s",
                flag,
                type(parsed).__name__,
            )
            continue
        override_dict = parsed

    return override_dict, args
def deep_update(target: dict, source: dict) -> None:
    """Merge ``source`` into ``target`` in place, descending into nested dicts.

    A key is merged recursively only when both sides hold a dict for it; in
    every other case the value from ``source`` replaces (or adds) the entry in
    ``target``.

    Args:
        target: Dictionary that receives the updates.
        source: Dictionary providing the new values.
    """
    for name, incoming in source.items():
        mergeable = (
            name in target
            and isinstance(incoming, dict)
            and isinstance(target[name], dict)
        )
        if not mergeable:
            target[name] = incoming
            continue
        deep_update(target[name], incoming)
class ConfigModifierProtocol(Protocol):
@classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
......@@ -185,7 +227,7 @@ class VllmV1ConfigModifier:
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
f"Missing extraPodSpec or mainContainer in VLLM decode worker service '{WORKER_COMPONENT_NAMES['vllm'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
......@@ -216,7 +258,7 @@ class VllmV1ConfigModifier:
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
f"Missing extraPodSpec or mainContainer in VLLM decode worker service '{WORKER_COMPONENT_NAMES['vllm'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
......@@ -264,7 +306,9 @@ class VllmV1ConfigModifier:
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError("Missing extraPodSpec or mainContainer in worker service")
raise ValueError(
f"Missing extraPodSpec or mainContainer in VLLM decode worker service '{WORKER_COMPONENT_NAMES['vllm'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -390,7 +434,7 @@ class SGLangConfigModifier:
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
f"Missing extraPodSpec or mainContainer in SGLang decode worker service '{WORKER_COMPONENT_NAMES['sglang'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
......@@ -420,7 +464,7 @@ class SGLangConfigModifier:
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
f"Missing extraPodSpec or mainContainer in SGLang decode worker service '{WORKER_COMPONENT_NAMES['sglang'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
......@@ -470,7 +514,9 @@ class SGLangConfigModifier:
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError("Missing extraPodSpec or mainContainer in worker service")
raise ValueError(
f"Missing extraPodSpec or mainContainer in SGLang decode worker service '{WORKER_COMPONENT_NAMES['sglang'].decode_worker_k8s_name}'"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -557,9 +603,261 @@ class SGLangConfigModifier:
return 0
class TrtllmConfigModifier:
    """Rewrite TRT-LLM deployment configs for profiling single-role workers.

    All methods validate the incoming dict into a ``Config`` model, edit the
    model, and return either a fresh dict (``model_dump``) or a value read from
    it; the caller's dict is not modified.
    """

    # Service names used in the TRT-LLM deployment specs.
    _AGG_WORKER = "TRTLLMWorker"
    _PREFILL_WORKER = "TRTLLMPrefillWorker"
    _DECODE_WORKER = "TRTLLMDecodeWorker"
    _OVERRIDE_FLAG = "--override-engine-args"

    @staticmethod
    def _split_worker_args(worker_service) -> list[str]:
        """Return the worker's main-container args as individual tokens."""
        pod_spec = worker_service.extraPodSpec
        if not pod_spec or not pod_spec.mainContainer:
            raise ValueError(
                "Missing extraPodSpec or mainContainer in TRTLLM worker service 'TRTLLMWorker'"
            )
        return break_arguments(pod_spec.mainContainer.args)

    @classmethod
    def _write_engine_overrides(
        cls, worker_service, args: list[str], override_dict: dict
    ) -> None:
        """Append the JSON override flag and store the re-joined args on the worker."""
        args = append_argument(args, [cls._OVERRIDE_FLAG, json.dumps(override_dict)])
        worker_service.extraPodSpec.mainContainer.args = join_arguments(args)

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Turn a disaggregated config into a single-worker aggregated one.

        The worker matching ``target`` is renamed to ``TRTLLMWorker``, the other
        worker and the planner are dropped, the disaggregation flags are removed
        and engine overrides are merged into ``--override-engine-args``.

        Args:
            config: Deployment config as a plain dict.
            target: Which role the remaining worker should play.

        Returns:
            The converted config as a dict.

        Raises:
            KeyError: If no ``TRTLLMWorker`` exists after conversion.
            ValueError: If the worker has no extraPodSpec/mainContainer.
        """
        cfg = Config.model_validate(config)

        # set metadata name
        cfg.metadata.name = "trtllm-agg"

        services = cfg.spec.services

        # disable planner
        if "Planner" in services:
            del services["Planner"]

        if target in ("prefill", "decode"):
            is_prefill = target == "prefill"
            keep = cls._PREFILL_WORKER if is_prefill else cls._DECODE_WORKER
            drop = cls._DECODE_WORKER if is_prefill else cls._PREFILL_WORKER

            # Rename the worker for the requested role to the generic worker.
            if keep in services:
                services[cls._AGG_WORKER] = services[keep]
                del services[keep]
            # Remove the other role's worker; guarded so a config that never had
            # one does not fail with a KeyError.
            if drop in services:
                del services[drop]

            worker_service = services[cls._AGG_WORKER]
            args = cls._split_worker_args(worker_service)

            # Remove disaggregation args
            args = remove_valued_arguments(args, "--disaggregation-mode")
            args = remove_valued_arguments(args, "--disaggregation-strategy")

            # The original --extra-engine-args file is kept since it may contain
            # user settings; any existing override is merged with the changes below.
            override_dict, args = parse_override_engine_args(args)

            # A user override such as `"kv_cache_config": null` cannot be indexed,
            # so anything that is not a dict is replaced by an empty one.
            if not isinstance(override_dict.get("kv_cache_config"), dict):
                override_dict["kv_cache_config"] = {}

            if is_prefill:
                # Prefill-only: no KV reuse, and the overlap scheduler (disabled
                # in prefill.yaml) is enabled for the aggregated run.
                override_dict["kv_cache_config"]["enable_block_reuse"] = False
                override_dict["disable_overlap_scheduler"] = False
            else:
                # Decode-only: block reuse is enabled to skip prefill.
                override_dict["kv_cache_config"]["enable_block_reuse"] = True
            # Remove cache transceiver for agg
            override_dict["cache_transceiver_config"] = None

            cls._write_engine_overrides(worker_service, args, override_dict)

        # Set num workers to 1
        services[cls._AGG_WORKER].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int):
        """Set the GPU count and ``tensor_parallel_size`` of ``TRTLLMWorker``.

        Args:
            config: Aggregated deployment config as a plain dict.
            tp_size: Tensor-parallel size; also used as the GPU request.

        Returns:
            The updated config as a dict.

        Raises:
            ValueError: If the worker has no extraPodSpec/mainContainer.
        """
        cfg = Config.model_validate(config)
        worker_service = cfg.spec.services[cls._AGG_WORKER]

        # Ensure resources and requests exist before writing the GPU count.
        if worker_service.resources is None:
            worker_service.resources = ServiceResources()
        if worker_service.resources.requests is None:
            worker_service.resources.requests = {}
        worker_service.resources.requests["gpu"] = str(tp_size)

        # Limits are only updated when already present.
        if worker_service.resources.limits is not None:
            worker_service.resources.limits["gpu"] = str(tp_size)

        # Handles both joined strings and lists.
        args = cls._split_worker_args(worker_service)

        # For TRT-LLM the tensor parallel size is passed through the
        # override-engine-args JSON rather than a dedicated flag.
        override_dict, args = parse_override_engine_args(args)
        override_dict["tensor_parallel_size"] = tp_size
        cls._write_engine_overrides(worker_service, args, override_dict)

        return cfg.model_dump()

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value of ``--served-model-name`` from the worker args.

        The aggregated worker is checked first, then the prefill and decode
        workers. ``DEFAULT_MODEL_NAME`` is returned (with a warning) when no
        worker, container or flag is found.
        """
        cfg = Config.model_validate(config)

        worker_service = None
        for worker_name in (cls._AGG_WORKER, cls._PREFILL_WORKER, cls._DECODE_WORKER):
            worker_service = cfg.spec.services.get(worker_name)
            if worker_service:
                break

        if not worker_service:
            logger.warning(
                f"Worker service not found, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        if (
            not worker_service.extraPodSpec
            or not worker_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(worker_service.extraPodSpec.mainContainer.args)
        for i, arg in enumerate(args):
            if arg == "--served-model-name" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the frontend HTTP port.

        The port is read from ``--http-port`` (either ``--http-port N`` or
        ``--http-port=N``) in the Frontend container args when present;
        otherwise ``DYNAMO_RUN_DEFAULT_PORT`` is returned.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = break_arguments(frontend_service.extraPodSpec.mainContainer.args)
        for i, arg in enumerate(args):
            if arg == "--http-port" and i + 1 < len(args):
                raw_port = args[i + 1]
            elif arg.startswith("--http-port="):
                raw_port = arg.split("=", 1)[1]
            else:
                continue
            try:
                return int(raw_port)
            except ValueError:
                logger.warning(
                    f"Invalid --http-port value '{raw_port}', using default port: {DYNAMO_RUN_DEFAULT_PORT}"
                )
                break

        # No usable --http-port in the frontend args: fall back to the default.
        return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Read the paged KV cache capacity (in tokens) from a TRT-LLM log.

        Looks for the first line of the form
        ``[TensorRT-LLM][INFO] [MemUsageChange] Allocated XX GiB for max tokens in paged KV cache (XXXXXX).``
        and returns the number in parentheses. Any failure to read or parse the
        file is logged and the fallback value 100000 is returned.
        """
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    # Cheap substring checks first; the regex only runs on candidates.
                    if (
                        "Allocated" not in line
                        or "for max tokens in paged KV cache" not in line
                    ):
                        continue
                    match = re.search(r"paged KV cache \((\d+)\)", line)
                    if match:
                        max_tokens = int(match.group(1))
                        logger.info(f"Found TRT-LLM KV cache max tokens: {max_tokens}")
                        return max_tokens
        except Exception as e:
            # Best effort: profiling continues with the fallback value.
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")

        logger.warning(
            "Could not find KV cache size in TRT-LLM logs, using default value of 100000"
        )
        return 100000  # Default fallback value for TRT-LLM
CONFIG_MODIFIERS: dict[str, type[ConfigModifierProtocol]] = {
"vllm": VllmV1ConfigModifier,
"sglang": SGLangConfigModifier,
"trtllm": TrtllmConfigModifier,
}
# Re-export WORKER_COMPONENT_NAMES for profile_sla.py
......
......@@ -55,7 +55,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
| [**Disaggregated Serving**](../../../docs/architecture/disagg_serving.md) | ✅ | |
| [**Conditional Disaggregation**](../../../docs/architecture/disagg_serving.md#conditional-disaggregation) | 🚧 | Not supported yet |
| [**KV-Aware Routing**](../../../docs/architecture/kv_cache_routing.md) | ✅ | |
| [**SLA-Based Planner**](../../../docs/architecture/sla_planner.md) | 🚧 | Planned |
| [**SLA-Based Planner**](../../../docs/architecture/sla_planner.md) | ✅ | |
| [**Load Based Planner**](../../../docs/architecture/load_planner.md) | 🚧 | Planned |
| [**KVBM**](../../../docs/architecture/kvbm_architecture.md) | 🚧 | Planned |
......
......@@ -42,6 +42,19 @@ Aggregated deployment with custom configuration.
- `Frontend`: OpenAI-compatible API server (with kv router mode disabled)
- `TRTLLMWorker`: Single worker handling both prefill and decode with custom configuration mounted from the configmap
### 6. **Disaggregated Planner Deployment** (`disagg_planner.yaml`)
Advanced disaggregated deployment with SLA-based automatic scaling.
**Architecture:**
- `Frontend`: HTTP API server coordinating between workers
- `Planner`: SLA-based planner that monitors performance and scales workers automatically
- `Prometheus`: Metrics collection and monitoring
- `TRTLLMDecodeWorker`: Specialized decode-only worker
- `TRTLLMPrefillWorker`: Specialized prefill-only worker
> [!NOTE]
> This deployment requires pre-deployment profiling to be completed first. See [Pre-Deployment Profiling](../../../../docs/benchmarks/pre_deployment_profiling.md) for detailed instructions.
## CRD Structure
All templates use the **DynamoGraphDeployment** CRD:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# DynamoGraphDeployment for TRT-LLM disaggregated serving with the SLA planner:
# a frontend, the planner, a Prometheus component, and separate decode/prefill workers.
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
name: trtllm-disagg-planner
spec:
envs:
# Prometheus configuration passed as JSON: 5s scrape interval, scraping itself
# (localhost:8000) and the frontend service on port 8000.
- name: DYNAMO_SERVICE_CONFIG
value: '{"Prometheus":{"global":{"scrape_interval":"5s"},"scrape_configs":[{"job_name":"prometheus","static_configs":[{"targets":["localhost:8000"]}]},{"job_name":"frontend","static_configs":[{"targets":["trtllm-disagg-planner-frontend:8000"]}]}]}}'
- name: DYNAMO_NAMESPACE
value: "trtllm-disagg-planner"
services:
# Frontend: HTTP server on port 8000 (the port the Prometheus "frontend" job scrapes).
Frontend:
dynamoNamespace: trtllm-disagg-planner
componentType: frontend
replicas: 1
extraPodSpec:
mainContainer:
# NOTE(review): this looks like a personal development image tag; presumably it
# should be replaced by a released image before use — confirm.
image: nvcr.io/nvidian/dynamo-dev/dynamo-trtllm-runtime:hzhou-0909-03
workingDir: /workspace/components/backends/trtllm
command:
- python3
args:
- -m
- dynamo.frontend
- --http-port
- "8000"
- --kv-cache-block-size
- "128"
- --router-mode
- kv
- --kv-overlap-score-weight
- "0.0"
- --router-temperature
- "0.0"
- --no-kv-events
# Planner: runs planner_sla with the trtllm backend against profiling results on the PVC.
Planner:
dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret
componentType: planner
replicas: 1
envs:
- name: PROMETHEUS_PORT
value: "8000"
# Both probes run `exit 0`, so they always succeed.
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 60
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
pvc:
create: false
name: dynamo-pvc # Must be pre-created before deployment and SLA profiler must have been run
mountPoint: /workspace/profiling_results
extraPodSpec:
mainContainer:
image: nvcr.io/nvidian/dynamo-dev/dynamo-trtllm-runtime:hzhou-0909-03
workingDir: /workspace/components/planner/src/dynamo/planner
ports:
# Same port as --prometheus-port below.
- name: metrics
containerPort: 9085
command:
- python3
args:
- -m
- planner_sla
- --environment=kubernetes
- --backend=trtllm
- --adjustment-interval=60
# Same path as the PVC mountPoint above.
- --profile-results-dir=/workspace/profiling_results
- --prometheus-port=9085
Prometheus: # NOTE: this is set on Prometheus to ensure a service is created for the Prometheus component. This is a workaround and should be managed differently.
dynamoNamespace: trtllm-disagg-planner
componentType: frontend
replicas: 1
envs:
- name: PYTHONPATH
value: "/workspace/components/planner/src"
- name: PROMETHEUS_PORT
value: "8000"
# Both probes run `exit 0`, so they always succeed.
livenessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
readinessProbe:
exec:
command:
- /bin/sh
- -c
- "exit 0"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 30
failureThreshold: 10
extraPodSpec:
mainContainer:
image: nvcr.io/nvidian/dynamo-dev/dynamo-trtllm-runtime:hzhou-0909-03
workingDir: /workspace/components/backends/trtllm
command:
- python3
args:
- -m
- dynamo.planner.prometheus
# Decode worker: one GPU, HTTP health probes on port 9090.
TRTLLMDecodeWorker:
dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret
componentType: worker
replicas: 1
# failureThreshold 1: a single failed /live check marks the container as failed.
livenessProbe:
httpGet:
path: /live
port: 9090
periodSeconds: 5
timeoutSeconds: 30
failureThreshold: 1
readinessProbe:
httpGet:
path: /health
port: 9090
periodSeconds: 10
timeoutSeconds: 30
failureThreshold: 60
resources:
limits:
gpu: "1"
extraPodSpec:
terminationGracePeriodSeconds: 600
mainContainer:
# Startup allows up to 60 checks at 10s intervals before giving up.
startupProbe:
httpGet:
path: /health
port: 9090
periodSeconds: 10
failureThreshold: 60
image: nvcr.io/nvidian/dynamo-dev/dynamo-trtllm-runtime:hzhou-0909-03
workingDir: /workspace/components/backends/trtllm
command:
- python3
args:
- -m
- dynamo.trtllm
- --model-path
- Qwen/Qwen3-0.6B
- --served-model-name
- Qwen/Qwen3-0.6B
- --extra-engine-args
- engine_configs/decode.yaml
- --disaggregation-mode
- decode
- --disaggregation-strategy
- decode_first
# Prefill worker: one GPU, same model and strategy as the decode worker.
# NOTE(review): unlike the decode worker, no livenessProbe/readinessProbe is
# declared here (only the startupProbe) — confirm this is intentional.
TRTLLMPrefillWorker:
dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret
componentType: worker
replicas: 1
resources:
limits:
gpu: "1"
extraPodSpec:
terminationGracePeriodSeconds: 600
mainContainer:
startupProbe:
httpGet:
path: /health
port: 9090
periodSeconds: 10
failureThreshold: 60
image: nvcr.io/nvidian/dynamo-dev/dynamo-trtllm-runtime:hzhou-0909-03
workingDir: /workspace/components/backends/trtllm
command:
- python3
args:
- -m
- dynamo.trtllm
- --model-path
- Qwen/Qwen3-0.6B
- --served-model-name
- Qwen/Qwen3-0.6B
- --extra-engine-args
- engine_configs/prefill.yaml
- --disaggregation-mode
- prefill
- --disaggregation-strategy
- decode_first
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import os
import signal
......@@ -22,6 +23,7 @@ from torch.cuda import device_count
from transformers import AutoConfig
import dynamo.nixl_connect as nixl_connect
from benchmarks.profiler.utils.config import deep_update
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -192,6 +194,17 @@ async def init(runtime: DistributedRuntime, config: Config):
if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
# Apply override_engine_args if provided
if config.override_engine_args != "":
try:
overrides = json.loads(config.override_engine_args)
logging.info(f"Applying engine arg overrides: {overrides}")
deep_update(arg_map, overrides)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse override_engine_args as JSON: {e}")
sys.exit(1)
if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config = None
......
......@@ -46,6 +46,7 @@ class Config:
self.max_beam_width: int = BuildConfig.max_beam_width
self.free_gpu_memory_fraction: Optional[float] = None
self.extra_engine_args: str = ""
self.override_engine_args: str = ""
self.publish_events_and_metrics: bool = False
self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
self.disaggregation_strategy: DisaggregationStrategy = (
......@@ -77,6 +78,7 @@ class Config:
f"max_beam_width={self.max_beam_width}, "
f"free_gpu_memory_fraction={self.free_gpu_memory_fraction}, "
f"extra_engine_args={self.extra_engine_args}, "
f"override_engine_args={self.override_engine_args}, "
f"migration_limit={self.migration_limit}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
......@@ -217,6 +219,12 @@ def cmd_line_args():
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--override-engine-args",
type=str,
default="",
help='Python dictionary string to override specific engine arguments from the YAML file. Example: \'{"tensor_parallel_size": 2, "kv_cache_config": {"enable_block_reuse": false}}\'',
)
parser.add_argument(
"--publish-events-and-metrics",
action="store_true",
......@@ -352,6 +360,7 @@ def cmd_line_args():
config.kv_block_size = args.kv_block_size
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
config.override_engine_args = args.override_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.modality = args.modality
......
......@@ -101,7 +101,22 @@ class SGLangComponentName:
decode_worker_endpoint = "generate"
class TrtllmComponentName:
    """Service, component and endpoint names of the TRT-LLM workers.

    The planner only supports the DECODE_FIRST strategy in TRT-LLM, where the
    decode worker is the first worker (``tensorrt_llm``) and the prefill worker
    is the next one (``tensorrt_llm_next``).
    """

    # Decode is "first" with DECODE_FIRST
    decode_worker_k8s_name = "TRTLLMDecodeWorker"
    decode_worker_component_name = "tensorrt_llm"
    decode_worker_endpoint = "generate"

    # Prefill is "next" with DECODE_FIRST
    prefill_worker_k8s_name = "TRTLLMPrefillWorker"
    prefill_worker_component_name = "tensorrt_llm_next"
    prefill_worker_endpoint = "generate"
WORKER_COMPONENT_NAMES = {
"vllm": VllmComponentName,
"sglang": SGLangComponentName,
"trtllm": TrtllmComponentName,
}
......@@ -39,7 +39,7 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--backend",
default=SLAPlannerDefaults.backend,
choices=["vllm", "sglang"],
choices=["vllm", "sglang", "trtllm"],
help="Backend type",
)
parser.add_argument(
......
......@@ -237,6 +237,7 @@ COPY components/ /workspace/components/
COPY tests /workspace/tests
COPY benchmarks /workspace/benchmarks
COPY examples /workspace/examples
COPY deploy /workspace/deploy
RUN uv pip install /workspace/benchmarks
# Copy benchmarks, backends and tests for CI
......
......@@ -322,6 +322,10 @@ RUN . /opt/dynamo/venv/bin/activate && \
RUN pip install dist/ai_dynamo_runtime*cp312*.whl && \
pip install dist/ai_dynamo*any.whl
# Install common dependencies including aiofiles
RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \
pip install --requirement /tmp/requirements.txt
ENV DYNAMO_HOME=/workspace
# Copy launch banner
RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/launch_message.txt \
......
......@@ -24,7 +24,7 @@ import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import aiofiles # type: ignore[import-untyped]
import aiofiles
import httpx # added for HTTP requests
import kubernetes_asyncio as kubernetes
import yaml
......
......@@ -44,7 +44,7 @@ Key features include:
- ✅
- vLLM
* -
-
-
- TensorRT-LLM
* -
- ❌
......
......@@ -14,7 +14,7 @@ Support matrix:
| vLLM | MoE | 🚧 |
| SGLang | Dense | ✅ |
| SGLang | MoE | 🚧 |
| TensorRT-LLM | Dense | 🚧 |
| TensorRT-LLM | Dense | ✅ |
| TensorRT-LLM | MoE | 🚧 |
> [!NOTE]
......
......@@ -34,11 +34,13 @@ export NAMESPACE=your-namespace
## 1. Deploy the System
We use vllm as the backend engine in this guide. SLA planner also supports SGLang and will support TensorRT-LLM. Checkout `disagg_planner.yaml` in their example deployment folders for more details. The deployment is the same for all backends.
We use vllm as the backend engine in this guide. SLA planner also supports SGLang and TensorRT-LLM. Check out `disagg_planner.yaml` in their example deployment folders for more details. The deployment is the same for all backends.
```bash
# Apply the disaggregated planner deployment
kubectl apply -f components/backends/vllm/deploy/disagg_planner.yaml -n $NAMESPACE # for vllm
# kubectl apply -f components/backends/sglang/deploy/disagg_planner.yaml -n $NAMESPACE # for sglang
# kubectl apply -f components/backends/trtllm/deploy/disagg_planner.yaml -n $NAMESPACE # for trtllm
# Check deployment status
kubectl get pods -n $NAMESPACE
......@@ -46,6 +48,7 @@ kubectl get pods -n $NAMESPACE
Expected pods (all should be `1/1 Running`):
```
# For vLLM:
vllm-disagg-planner-frontend-* 1/1 Running
vllm-disagg-planner-prometheus-* 1/1 Running
vllm-disagg-planner-planner-* 1/1 Running
......
......@@ -5,7 +5,7 @@
Test suite for profile_sla dry-run functionality.
This test ensures that the profile_sla script can successfully run in dry-run mode
for both vllm and sglang backends with their respective disagg.yaml configurations.
for vllm, sglang, and trtllm backends with their respective disagg.yaml configurations.
"""
import sys
......@@ -86,3 +86,35 @@ class TestProfileSLADryRun:
"""Test that profile_sla dry-run works for sglang backend with disagg.yaml config."""
# Run the profile in dry-run mode - should complete without errors
await run_profile(sglang_args)
@pytest.fixture
def trtllm_args(self):
    """Build the argument object for the trtllm dry-run profiling test."""

    class Args:
        # Backend selection and deployment config under test.
        backend = "trtllm"
        config = "components/backends/trtllm/deploy/disagg.yaml"
        namespace = "test-namespace"
        service_name = ""
        output_dir = "/tmp/test_profiling_results"

        # Run control: dry run, nothing skipped or forced.
        dry_run = True
        skip_existing_results = False
        force_rerun = False

        # GPU search range per engine.
        min_num_gpus_per_engine = 1
        max_num_gpus_per_engine = 8

        # Workload shape and latency targets.
        isl = 3000
        osl = 500
        ttft = 50
        itl = 10
        max_context_length = 16384

        # Interpolation granularity.
        prefill_interpolation_granularity = 16
        decode_interpolation_granularity = 6

    return Args()
@pytest.mark.pre_merge
@pytest.mark.asyncio
async def test_trtllm_dryrun(self, trtllm_args):
    """Dry-run profiling with the trtllm backend and its disagg.yaml must not raise."""
    # The test passes when run_profile returns without an exception.
    await run_profile(trtllm_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment