"tests/vscode:/vscode.git/clone" did not exist on "47b7af0d87705f2e086ea0bc9d915fc7510e8e2f"
Unverified Commit fe21d3dd authored by William Zhang, committed by GitHub
Browse files

feat: Enable autodeploy as a backend for TRT-LLM (#4347)


Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
parent c55509ae
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import enum
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional
from tensorrt_llm import LLM from tensorrt_llm import LLM
logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__)
class Backend(str, enum.Enum):
    """TensorRT-LLM backends that the engine wrapper knows how to construct.

    Mixing in `str` makes each member compare equal to its raw string value, so
    a plain string such as "pytorch" can be compared directly against a member.
    """

    # Selects the regular `tensorrt_llm.LLM` entry point.
    PYTORCH = "pytorch"
    # Selects the AutoDeploy `LLM` entry point.
    AUTODEPLOY = "_autodeploy"
class TensorRTLLMEngine: class TensorRTLLMEngine:
def __init__(self, engine_args): def __init__(self, engine_args):
self.engine_args = engine_args
self._llm: Optional[LLM] = None self._llm: Optional[LLM] = None
backend = engine_args.pop("backend", Backend.PYTORCH)
if backend == Backend.PYTORCH:
self._llm_cls = LLM
elif backend == Backend.AUTODEPLOY:
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
self._llm_cls = AutoDeployLLM
self._prune_engine_args_for_autodeploy(engine_args)
else:
raise ValueError(
f"Unsupported {backend=}. Available backends: {[b.value for b in Backend]}."
)
self.engine_args = engine_args
async def initialize(self): async def initialize(self):
if not self._llm: if not self._llm:
model = self.engine_args.pop("model") self._llm = self._llm_cls(**self.engine_args)
self._llm = LLM(
model=model,
**self.engine_args,
)
async def cleanup(self): async def cleanup(self):
if self._llm: if self._llm:
...@@ -38,6 +55,40 @@ class TensorRTLLMEngine: ...@@ -38,6 +55,40 @@ class TensorRTLLMEngine:
raise RuntimeError("Engine not initialized") raise RuntimeError("Engine not initialized")
return self._llm return self._llm
@staticmethod
def _prune_engine_args_for_autodeploy(engine_args) -> None:
    """Strip, in place, the `engine_args` entries the autodeploy backend does not support.

    Each listed key is popped from `engine_args`. A warning is emitted for every key
    whose popped value was not `None`; missing keys and `None` values are dropped silently.
    """
    # TODO(2ez4bz/lucaslie): consider handling this in AutoDeploy's `LlmArgs` itself.
    rejected_keys = (
        # https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/auto_deploy/
        # llm_args.py#L313
        "build_config",
        # https://github.com/NVIDIA/TensorRT-LLM/blob/b51258acdd968599b2c3756d5a5326e7d750e7bf/
        # tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py#L384
        "scheduler_config",
        # The below all come from:
        # https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/auto_deploy/
        # llm_args.py#L316
        "tensor_parallel_size",
        "pipeline_parallel_size",
        "context_parallel_size",
        "moe_cluster_parallel_size",
        "moe_tensor_parallel_size",
        "moe_expert_parallel_size",
        "enable_attention_dp",
        "cp_config",
    )
    for key in rejected_keys:
        dropped_value = engine_args.pop(key, None)
        if dropped_value is None:
            # Nothing meaningful was configured for this key; no need to warn.
            continue
        TensorRTLLMEngine._warn_about_unsupported_field(key)
@staticmethod
def _warn_about_unsupported_field(field_name: str) -> None:
    """Log a warning that `field_name` is ignored when using the `_autodeploy` backend."""
    # Keep `%s` formatting lazy so the logger does the interpolation.
    message = "`%s` cannot be used with the `_autodeploy` backend. Ignoring."
    logger.warning(message, field_name)
@asynccontextmanager @asynccontextmanager
async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]: async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]:
......
...@@ -41,7 +41,7 @@ from dynamo.common.utils.prometheus import register_engine_metrics_callback ...@@ -41,7 +41,7 @@ from dynamo.common.utils.prometheus import register_engine_metrics_callback
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import get_publisher from dynamo.trtllm.publisher import get_publisher
...@@ -181,7 +181,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -181,7 +181,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"tensor_parallel_size": config.tensor_parallel_size, "tensor_parallel_size": config.tensor_parallel_size,
"pipeline_parallel_size": config.pipeline_parallel_size, "pipeline_parallel_size": config.pipeline_parallel_size,
"moe_expert_parallel_size": config.expert_parallel_size, "moe_expert_parallel_size": config.expert_parallel_size,
"backend": "pytorch", "backend": Backend.PYTORCH,
"skip_tokenizer_init": True, "skip_tokenizer_init": True,
"build_config": build_config, "build_config": build_config,
"kv_cache_config": kv_cache_config, "kv_cache_config": kv_cache_config,
...@@ -226,10 +226,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -226,10 +226,12 @@ async def init(runtime: DistributedRuntime, config: Config):
# Only pytorch backend is supported for now to publish events and metrics. # Only pytorch backend is supported for now to publish events and metrics.
if "backend" not in arg_map: if "backend" not in arg_map:
arg_map["backend"] = "pytorch" arg_map["backend"] = Backend.PYTORCH
elif arg_map["backend"] != "pytorch": elif arg_map["backend"] not in Backend:
logging.error( logging.error(
"Only pytorch backend is supported for now to publish events and metrics." "Only %s supported for now to publish events and metrics. Got: %s",
[b.value for b in Backend],
arg_map["backend"],
) )
sys.exit(1) sys.exit(1)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for autodeploy backend support in TRTLLM."""
import contextlib
from unittest import mock
import pydantic
import pytest
from tensorrt_llm._torch.auto_deploy import LlmArgs as ADLlmArgs
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
# Markers applied to every test in this module.
pytestmark = [
pytest.mark.unit,
pytest.mark.trtllm_marker,
# NOTE: these tests do not actually require a GPU, but the workflow validation
# `.github/workflows/container-validation-backends.yml` does not make use of
# the `gpu_0` marker.
pytest.mark.gpu_1,
]
# Dotted paths handed to `mock.patch` to stand in for the LLM class of each backend.
_PYTORCH_LLM_CLS_NAME = "dynamo.trtllm.engine.LLM"
_AUTODEPLOY_LLM_CLS_NAME = "tensorrt_llm._torch.auto_deploy.LLM"
class TestTensorRTLLMEngine:
    """Unit tests for backend selection and argument pruning in `TensorRTLLMEngine`."""

    @pytest.mark.parametrize("backend", ["foo", "bar", "cpp"])
    def test_raises_on_unsupported_backends(self, backend):
        """An unknown backend name is rejected with a `ValueError` at construction."""
        with pytest.raises(ValueError, match="Unsupported backend"):
            TensorRTLLMEngine(engine_args={"backend": backend})

    @pytest.mark.parametrize(
        "backend, expected_cls_name",
        [
            ("pytorch", _PYTORCH_LLM_CLS_NAME),
            ("_autodeploy", _AUTODEPLOY_LLM_CLS_NAME),
        ],
    )
    @pytest.mark.asyncio
    async def test_picks_expected_llm_cls(self, backend, expected_cls_name):
        """`initialize()` instantiates the LLM class that matches the backend."""
        with mock.patch(expected_cls_name) as mocked_cls:
            engine = TensorRTLLMEngine(engine_args={"backend": backend})
            await engine.initialize()
        mocked_cls.assert_called_once()

    @pytest.mark.parametrize(
        "engine_args, is_forbidden",
        [
            ({"build_config": {}}, True),
            ({"tensor_parallel_size": 7}, True),
            ({"pipeline_parallel_size": 3}, True),
            ({"context_parallel_size": 3}, True),
            ({"moe_cluster_parallel_size": 3}, True),
            ({"moe_tensor_parallel_size": 3}, True),
            ({"moe_expert_parallel_size": 3}, True),
            ({"enable_attention_dp": True}, True),
            # Default value is an empty dict, so pass a non-empty *dict* (not a set).
            ({"cp_config": {"foo": "bar"}}, True),
            ({"scheduler_config": {}}, False),
        ],
    )
    @pytest.mark.asyncio
    async def test_unsupported_args_get_pruned_for_autodeploy(
        self, engine_args, is_forbidden
    ):
        """Fields AutoDeploy rejects are removed by the engine before use."""
        # Work on a copy: the engine constructor pops keys from the dict it is handed, and
        # the parametrized dicts are created once at collection time. Mutating them in place
        # would make a rerun of this test see already-pruned arguments.
        engine_args = {**engine_args, "backend": Backend.AUTODEPLOY}

        # This allows us to catch cases where a field being pruned away is now supported by
        # AutoDeploy when bumping TRT-LLM.
        if is_forbidden:
            expectation = pytest.raises(pydantic.ValidationError)
        else:
            expectation = contextlib.nullcontext()
        with expectation:
            ADLlmArgs(model="foo", **engine_args)

        engine = TensorRTLLMEngine(engine_args=engine_args)

        # This should no longer throw an error since the pruning should have kicked in.
        ADLlmArgs(model="foo", **engine.engine_args)
@pytest.mark.parametrize("backend", ["pytorch", "_autodeploy"])
@pytest.mark.asyncio
async def test_get_llm_engine_forwards_backend(backend):
    """`get_llm_engine` builds `TensorRTLLMEngine` from the exact args it was given."""
    forwarded_args = {"foo": mock.Mock(), "backend": backend}
    # Replace the engine class so entering the context manager never builds a real LLM.
    engine_patch = mock.patch(
        "dynamo.trtllm.engine.TensorRTLLMEngine", return_value=mock.AsyncMock()
    )
    with engine_patch as engine_ctor:
        async with get_llm_engine(engine_args=forwarded_args):
            pass
        engine_ctor.assert_called_once_with(forwarded_args)
...@@ -136,7 +136,7 @@ minversion = "8.0" ...@@ -136,7 +136,7 @@ minversion = "8.0"
tmp_path_retention_policy = "failed" tmp_path_retention_policy = "failed"
# NOTE # NOTE
# We ignore model.py explcitly here to avoid mypy errors with duplicate modules # We ignore model.py explicitly here to avoid mypy errors with duplicate modules
# pytest overrides the default mypy exclude configuration and so we exclude here as well # pytest overrides the default mypy exclude configuration and so we exclude here as well
addopts = [ addopts = [
"-ra", "-ra",
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
backend: _autodeploy
kv_cache_config:
enable_partial_reuse: false
free_gpu_memory_fraction: 0.80
max_tokens: 8192
compile_backend: torch-cudagraph
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Integration test for the autodeploy backend in TRTLLM."""
import logging
import os
import pathlib
import shutil
import pytest
import requests
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from tests.utils.payloads import check_models_api
logger = logging.getLogger(__name__)
# Just need a model to show the config works rather than any stress of the system.
MODEL_PATH = "Qwen/Qwen3-0.6B"
# Name the worker is asked to serve the model under; also sent as the `model` field of requests.
SERVED_MODEL_NAME = MODEL_PATH
# Prompt used for both the warmup request and the smoke-test request.
PROMPT = "Takes skill to be real"
# TODO: turn into a fixture that _many_ tests can benefit from.
class DynamoWorkerProcess(ManagedProcess):
    """Process manager for Dynamo worker with TRTLLM backend."""

    # Port exported to the worker via `DYN_SYSTEM_PORT` and polled at `/health`.
    # Kept in one place so the env var and the health-check URL cannot drift apart.
    _SYSTEM_PORT = 9345

    def __init__(self, request, worker_id: str, engine_config: str):
        """Build the worker command line and environment, then hand off to `ManagedProcess`.

        Args:
            request: The pytest `request` fixture; its node name is used for the log directory.
            worker_id: Label used in log messages and in the log directory name.
            engine_config: Path passed to the worker through `--extra-engine-args`.
        """
        self.worker_id = worker_id

        command = [
            "python3",
            "-m",
            "dynamo.trtllm",
            "--model",
            MODEL_PATH,
            "--served-model-name",
            SERVED_MODEL_NAME,
            "--extra-engine-args",
            engine_config,
        ]

        # Set debug logging environment
        env = os.environ.copy()
        env["DYN_LOG"] = "debug"
        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
        env["DYN_SYSTEM_PORT"] = str(self._SYSTEM_PORT)
        env["DYN_KVBM_CPU_CACHE_GB"] = "20"
        env["DYN_KVBM_DISK_CACHE_GB"] = "60"
        env["DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS"] = "1200"

        # TODO: Have the managed process take a command name explicitly to distinguish
        # between processes started with the same command.
        log_dir = f"{request.node.name}_{worker_id}"

        # Clean up any existing log directory from previous runs
        try:
            shutil.rmtree(log_dir)
            logger.info(f"Cleaned up existing log directory: {log_dir}")
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass

        super().__init__(
            command=command,
            env=env,
            health_check_urls=[
                (f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
                (f"http://localhost:{self._SYSTEM_PORT}/health", self.is_ready),
            ],
            timeout=360,
            display_output=True,
            terminate_existing=False,
            log_dir=log_dir,
        )

    def get_pid(self) -> int | None:
        """Get the PID of the worker process, or `None` if no process is attached."""
        proc = getattr(self, "proc", None)
        return proc.pid if proc else None

    def is_ready(self, response) -> bool:
        """Return True when the health response body is JSON with `status == "ready"`."""
        label = f"{self.__class__.__name__} {{ name: {self.worker_id} }}"
        try:
            data = response.json()
        except ValueError:
            logger.warning(f"{label} health response is not valid JSON")
            return False

        if data.get("status") == "ready":
            logger.info(f"{label} status is ready")
            return True

        logger.warning(f"{label} status is not ready: {data.get('status')}")
        return False

    def __enter__(self):
        """Start the process and perform warmup request to trigger compilation.

        Without a build cache, the autodeploy LLM engine will have to run some compilation before
        being able to actually execute requests. We add a warmup stage here so that we can have
        tighter timeouts on the requests sent during the actual tests.

        Raises:
            RuntimeError: If the warmup request returns a non-OK status.
            Exception: Any error raised while sending the warmup request is re-raised
                after the process has been torn down.
        """
        result = super().__enter__()

        logger.info(
            f"Sending warmup request to {self.worker_id} to trigger compilation..."
        )
        try:
            warmup_response = send_completion_request(
                prompt=PROMPT,
                max_tokens=1,
                timeout=300,
            )
            if not warmup_response.ok:
                raise RuntimeError(
                    f"Warmup request returned status {warmup_response.status_code} for {self.worker_id}"
                )
        except Exception as e:
            logger.error(f"Warmup request failed for {self.worker_id}: {e}")
            # A `with` statement does not call `__exit__` when `__enter__` raises, so the
            # already-started worker must be torn down here or it would be leaked.
            self.__exit__(type(e), e, e.__traceback__)
            raise
        else:
            logger.info(f"Warmup request completed successfully for {self.worker_id}")

        return result
def send_completion_request(
    prompt: str, max_tokens: int, timeout: int = 120
) -> requests.Response:
    """Send a non-streaming completion request to the frontend.

    Args:
        prompt: Text to complete.
        max_tokens: Value sent as the `max_tokens` field of the request.
        timeout: Timeout, in seconds, handed to `requests.post`.

    Returns:
        The HTTP response; the status is not checked here.

    Raises:
        requests.exceptions.RequestException: Logged and re-raised on any request
            failure, including timeouts.
    """
    payload = {
        "model": SERVED_MODEL_NAME,
        "prompt": prompt,
        "stream": False,
        "max_tokens": max_tokens,
    }
    headers = {"Content-Type": "application/json"}

    logger.info(
        f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
    )

    try:
        # Target the same frontend port that the worker health check polls, instead of
        # a separately hard-coded port number.
        return requests.post(
            f"http://localhost:{FRONTEND_PORT}/v1/completions",
            headers=headers,
            json=payload,
            timeout=timeout,
        )
    except requests.exceptions.Timeout:
        logger.error(f"Request timed out after {timeout} seconds")
        raise
    except requests.exceptions.RequestException as e:
        logger.error(f"Request failed with error: {e}")
        raise
@pytest.mark.trtllm_marker
@pytest.mark.e2e
@pytest.mark.slow
@pytest.mark.gpu_1
def test_smoke(request, runtime_services):
    """Smoke-test a TRTLLM worker on the autodeploy backend with one completion request.

    Starts the frontend, then a worker configured from the YAML file that sits next to
    this module, and checks that a single completion request returns an OK status.
    """
    logger.info("Starting frontend...")
    with DynamoFrontendProcess(request):
        logger.info("Frontend started.")

        # The engine config lives alongside this test file.
        config_file = pathlib.Path(__file__).parent / "autodeploy_engine_config.yaml"
        engine_config_path = str(config_file)

        logger.info("Starting worker...")
        with DynamoWorkerProcess(request, "decode", engine_config_path) as worker:
            logger.info(f"Worker PID: {worker.get_pid()}")

            response = send_completion_request(
                prompt=PROMPT, max_tokens=100, timeout=20
            )
            assert (
                response.ok
            ), f"Expected successful status, got {response.status_code}"

            logger.info(f"Completion request succeeded: {response.status_code}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment