Unverified Commit 1b1e089a authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

feat: Enable dynamo-run out=trtllm (#1223)

parent fc31a510
......@@ -12,7 +12,7 @@
* [llama.cpp](#llamacpp)
* [Sglang](#sglang)
* [Vllm](#vllm)
* [TensorRT-LLM](#tensorrt-llm-engine)
* [TensorRT-LLM](#trtllm)
* [Echo Engines](#echo-engines)
* [Writing your own engine in Python](#writing-your-own-engine-in-python)
* [Batch mode](#batch-mode)
......@@ -437,10 +437,13 @@ Startup can be slow so you may want to `export DYN_LOG=debug` to see progress.
Shutdown: `ray stop`
#### TensorRT-LLM engine
#### trtllm
To run a TRT-LLM model with dynamo-run we have included a Python-based [async engine](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/engines/agg_engine.py).
To configure the TensorRT-LLM async engine please see [llm_api_config.yaml](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/configs/llm_api_config.yaml). The file defines the options that need to be passed to the LLM engine. Follow the steps below to serve trtllm on dynamo run.
Using [TensorRT-LLM's LLM API](https://nvidia.github.io/TensorRT-LLM/llm-api/), a high-level Python API.
You can use `--extra-engine-args` to pass extra arguments to LLM API engine.
The trtllm engine requires [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`) to be running.
##### Step 1: Build the environment
......@@ -454,7 +457,7 @@ See instructions [here](https://github.com/ai-dynamo/dynamo/blob/main/examples/t
Execute the following to load the TensorRT-LLM model specified in the configuration.
```
dynamo run out=pystr:/workspace/examples/tensorrt_llm/engines/trtllm_engine.py -- --engine_args /workspace/examples/tensorrt_llm/configs/llm_api_config.yaml
dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0
```
#### Echo Engines
......@@ -529,6 +532,20 @@ Pass it like this:
```
dynamo-run out=sglang ~/llms/Llama-3.2-3B-Instruct --extra-engine-args sglang_extra.json
```
The tensorrtllm backend also supports passing any argument the engine accepts. However, in this case the config should be a YAML file.
```
backend: pytorch
kv_cache_config:
event_buffer_max_size: 1024
```
Pass it like this:
```
dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args trtllm_extra.yaml
```
### Writing your own engine in Python
Note: This section replaces "bring-your-own-engine".
......
......@@ -223,6 +223,40 @@ pub async fn run(
}));
EngineConfig::Dynamic
}
Output::Trtllm => {
if flags.base_gpu_id != 0 {
anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
// If `in=dyn` we want the trtllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => INTERNAL_ENDPOINT.parse()?,
};
let (py_script, child) = match subprocess::start(
subprocess::trtllm::PY,
&local_model,
&endpoint,
flags.clone(),
None, // multi-node config. trtllm uses `mpi`, see guide
)
.await
{
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting trtllm sub-process: {err}");
}
};
let cancel_token = cancel_token.clone();
// Sub-process cleanup
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
EngineConfig::Dynamic
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
......
......@@ -101,6 +101,9 @@ pub enum Output {
/// Run inference using sglang
SgLang,
/// Run inference using trtllm
Trtllm,
// Start vllm in a sub-process connecting via nats
// Sugar for `python vllm_inc.py --endpoint <thing> --model <thing>`
Vllm,
......@@ -125,6 +128,7 @@ impl TryFrom<&str> for Output {
"llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),
"sglang" => Ok(Output::SgLang),
"trtllm" => Ok(Output::Trtllm),
"vllm" => Ok(Output::Vllm),
"echo_full" => Ok(Output::EchoFull),
......@@ -164,6 +168,7 @@ impl fmt::Display for Output {
Output::LlamaCpp => "llamacpp",
Output::SgLang => "sglang",
Output::Trtllm => "trtllm",
Output::Vllm => "vllm",
Output::EchoFull => "echo_full",
......@@ -210,6 +215,7 @@ impl Output {
}
out.push(Output::SgLang.to_string());
out.push(Output::Trtllm.to_string());
out.push(Output::Vllm.to_string());
#[cfg(feature = "python")]
......
......@@ -15,6 +15,7 @@ use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang;
pub mod trtllm;
pub mod vllm;
pub async fn start(
......
......@@ -2,16 +2,17 @@
# SPDX-License-Identifier: Apache-2.0
# TODO:
# - Add event and metrics publishers
# - Support default dynamo-run out=trtllm launch
# - Support disaggregated serving
# - Update examples to use this engine.
#
# `dynamo-run out=trtllm` runs this script
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
import argparse
import asyncio
import logging
import sys
import warnings
from typing import Optional
import uvloop
......@@ -20,10 +21,13 @@ import uvloop
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from trtllm.engine import get_llm_engine
from trtllm.publishers import Publishers
from dynamo.llm import ModelType, register_llm
from dynamo.llm import (
ModelType,
get_tensorrtllm_engine,
get_tensorrtllm_publisher,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
# Only used if you run it manually from the command line
......@@ -44,7 +48,7 @@ class Config:
component: str
endpoint: str
model_path: str
model_name: Optional[str]
model_name: Optional[str] = None
tensor_parallel_size: int
kv_block_size: int
extra_engine_args: str
......@@ -65,7 +69,9 @@ class RequestHandler:
async def generate(self, request):
# Check if there is an error in the publishers error queue
publishers_error = self.publishers.check_error_queue()
publishers_error = (
self.publishers.check_error_queue() if self.publishers else None
)
if publishers_error:
raise publishers_error
......@@ -90,7 +96,7 @@ class RequestHandler:
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publishers:
self.publishers.start_publish_threads()
self.publishers.start()
self.first_generation = False
if res.finished:
......@@ -137,6 +143,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"disable_log_stats": False,
}
if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
......@@ -168,34 +175,33 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
async with get_llm_engine(engine_args) as engine:
async with get_tensorrtllm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
)
publishers = None
if config.publish_events_and_metrics:
# Initialize and pass in the publishers to the request handler to
# publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
publishers = Publishers(
async with get_tensorrtllm_publisher(
component,
engine,
kv_listener,
int(endpoint.lease_id()),
config.kv_block_size,
) as publisher:
handler = RequestHandler(
component, engine, default_sampling_params, publisher
)
handler = RequestHandler(component, engine, default_sampling_params, publishers)
try:
# The server shuts down gracefully (i.e., open TCP streams are allowed to finish)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
finally:
if publishers:
await publishers.cleanup()
else:
# No publishers, so just pass in None to the request handler.
handler = RequestHandler(component, engine, default_sampling_params, None)
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
......@@ -228,6 +234,12 @@ def cmd_line_args():
parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -241,6 +253,12 @@ def cmd_line_args():
)
args = parser.parse_args()
if args.context_length is not None:
warnings.warn(
"--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
UserWarning,
)
config = Config()
config.model_path = args.model_path
if args.model_name:
......
......@@ -33,3 +33,13 @@ from dynamo._core import KvRouter as KvRouter
from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import register_llm as register_llm
try:
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_llm_engine as get_tensorrtllm_engine,
)
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_publisher as get_tensorrtllm_publisher,
)
except ImportError:
pass # TensorRTLLM is not enabled by default
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine import get_llm_engine # noqa: F401
from .publisher import get_publisher # noqa: F401
......@@ -7,6 +7,7 @@ import logging
import threading
import traceback
import weakref
from contextlib import asynccontextmanager
from queue import Queue
from typing import Callable, Optional, Union
......@@ -80,7 +81,7 @@ class ManagedThread(threading.Thread):
self._current_future.cancel()
class Publishers:
class Publisher:
"""
A class to retrieve stats and kv cache events from TRTLLM engine and publish them to the metrics and events publishers.
"""
......@@ -102,7 +103,6 @@ class Publishers:
self.partial_block_hashes = set()
self.error_queue: Queue = Queue()
self._stop_event = threading.Event()
self._setup()
async def _create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
......@@ -111,7 +111,7 @@ class Publishers:
return
await self.metrics_publisher.create_endpoint(self.component)
def _setup(self):
def initialize(self):
# Setup the metrics publisher
self.metrics_publisher = KvMetricsPublisher()
self._init_publish_metrics_thread()
......@@ -298,7 +298,7 @@ class Publishers:
self.kv_event_publisher.publish_removed(event_id, block_hashes)
return True
def start_publish_threads(self):
def start(self):
if (
self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive()
......@@ -342,3 +342,16 @@ class Publishers:
self.publish_kv_cache_events_thread.join(timeout=cleanup_timeout)
if self.publish_kv_cache_events_thread.is_alive():
logging.warning("KV cache events thread did not stop within timeout")
@asynccontextmanager
async def get_publisher(component, engine, kv_listener, worker_id, kv_block_size):
    """Async context manager that yields an initialized ``Publisher``.

    Constructs a ``Publisher`` from the given arguments, calls its
    ``initialize()`` method, and yields it to the ``async with`` body.
    ``publisher.cleanup()`` is always awaited on exit, whether initialization
    and the body completed normally or raised.

    Args:
        component: Forwarded unchanged to ``Publisher``.
        engine: Forwarded unchanged to ``Publisher``.
        kv_listener: Forwarded unchanged to ``Publisher``.
        worker_id: Forwarded unchanged to ``Publisher``.
        kv_block_size: Forwarded unchanged to ``Publisher``.

    Yields:
        The initialized ``Publisher`` instance.

    Raises:
        Exception: Any exception raised by ``initialize()`` or by the body of
            the ``async with`` block is logged and then re-raised.
    """
    publisher = Publisher(component, engine, kv_listener, worker_id, kv_block_size)
    try:
        publisher.initialize()
        yield publisher
    except Exception as e:
        # Log on the way out so the failure is attributed to the publisher
        # context; the exception itself is still propagated to the caller.
        logging.error(f"Error in publisher context: {e}")
        raise
    finally:
        # Runs on both the success and the error path.
        await publisher.cleanup()
......@@ -141,6 +141,7 @@ addopts = [
"--ignore-glob=*model.py",
"--ignore-glob=*_inc.py",
"--ignore-glob=deploy/cloud/api-store/*",
"--ignore-glob=*/llm/tensorrtllm*",
# FIXME: Get relative/generic blob paths to work here
]
xfail_strict = true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment