Unverified Commit 1b1e089a authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

feat: Enable dynamo-run out=trtllm (#1223)

parent fc31a510
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
* [llama.cpp](#llamacpp) * [llama.cpp](#llamacpp)
* [Sglang](#sglang) * [Sglang](#sglang)
* [Vllm](#vllm) * [Vllm](#vllm)
* [TensorRT-LLM](#tensorrt-llm-engine) * [TensorRT-LLM](#trtllm)
* [Echo Engines](#echo-engines) * [Echo Engines](#echo-engines)
* [Writing your own engine in Python](#writing-your-own-engine-in-python) * [Writing your own engine in Python](#writing-your-own-engine-in-python)
* [Batch mode](#batch-mode) * [Batch mode](#batch-mode)
...@@ -437,10 +437,13 @@ Startup can be slow so you may want to `export DYN_LOG=debug` to see progress. ...@@ -437,10 +437,13 @@ Startup can be slow so you may want to `export DYN_LOG=debug` to see progress.
Shutdown: `ray stop` Shutdown: `ray stop`
#### TensorRT-LLM engine #### trtllm
To run a TRT-LLM model with dynamo-run we have included a python based [async engine] (https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/engines/agg_engine.py). Using [TensorRT-LLM's LLM API](https://nvidia.github.io/TensorRT-LLM/llm-api/), a high-level Python API.
To configure the TensorRT-LLM async engine please see [llm_api_config.yaml](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/configs/llm_api_config.yaml). The file defines the options that need to be passed to the LLM engine. Follow the steps below to serve trtllm on dynamo run.
You can use `--extra-engine-args` to pass extra arguments to LLM API engine.
The trtllm engine requires requires [etcd](https://etcd.io/) and [nats](https://nats.io/) with jetstream (`nats-server -js`) to be running.
##### Step 1: Build the environment ##### Step 1: Build the environment
...@@ -454,7 +457,7 @@ See instructions [here](https://github.com/ai-dynamo/dynamo/blob/main/examples/t ...@@ -454,7 +457,7 @@ See instructions [here](https://github.com/ai-dynamo/dynamo/blob/main/examples/t
Execute the following to load the TensorRT-LLM model specified in the configuration. Execute the following to load the TensorRT-LLM model specified in the configuration.
``` ```
dynamo run out=pystr:/workspace/examples/tensorrt_llm/engines/trtllm_engine.py -- --engine_args /workspace/examples/tensorrt_llm/configs/llm_api_config.yaml dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0
``` ```
#### Echo Engines #### Echo Engines
...@@ -529,6 +532,20 @@ Pass it like this: ...@@ -529,6 +532,20 @@ Pass it like this:
``` ```
dynamo-run out=sglang ~/llms/Llama-3.2-3B-Instruct --extra-engine-args sglang_extra.json dynamo-run out=sglang ~/llms/Llama-3.2-3B-Instruct --extra-engine-args sglang_extra.json
``` ```
The tensorrtllm backend also support passing any argument the engine accepts. However, in this case config should be a yaml file.
```
backend: pytorch
kv_cache_config:
event_buffer_max_size: 1024
```
Pass it like this:
```
dynamo-run in=http out=trtllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args trtllm_extra.yaml
```
### Writing your own engine in Python ### Writing your own engine in Python
Note: This section replaces "bring-your-own-engine". Note: This section replaces "bring-your-own-engine".
......
...@@ -223,6 +223,40 @@ pub async fn run( ...@@ -223,6 +223,40 @@ pub async fn run(
})); }));
EngineConfig::Dynamic EngineConfig::Dynamic
} }
Output::Trtllm => {
if flags.base_gpu_id != 0 {
anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
// If `in=dyn` we want the trtllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => INTERNAL_ENDPOINT.parse()?,
};
let (py_script, child) = match subprocess::start(
subprocess::trtllm::PY,
&local_model,
&endpoint,
flags.clone(),
None, // multi-node config. trtlllm uses `mpi`, see guide
)
.await
{
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting trtllm sub-process: {err}");
}
};
let cancel_token = cancel_token.clone();
// Sub-process cleanup
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
EngineConfig::Dynamic
}
#[cfg(feature = "llamacpp")] #[cfg(feature = "llamacpp")]
Output::LlamaCpp => { Output::LlamaCpp => {
......
...@@ -101,6 +101,9 @@ pub enum Output { ...@@ -101,6 +101,9 @@ pub enum Output {
/// Run inference using sglang /// Run inference using sglang
SgLang, SgLang,
/// Run inference using trtllm
Trtllm,
// Start vllm in a sub-process connecting via nats // Start vllm in a sub-process connecting via nats
// Sugar for `python vllm_inc.py --endpoint <thing> --model <thing>` // Sugar for `python vllm_inc.py --endpoint <thing> --model <thing>`
Vllm, Vllm,
...@@ -125,6 +128,7 @@ impl TryFrom<&str> for Output { ...@@ -125,6 +128,7 @@ impl TryFrom<&str> for Output {
"llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp), "llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),
"sglang" => Ok(Output::SgLang), "sglang" => Ok(Output::SgLang),
"trtllm" => Ok(Output::Trtllm),
"vllm" => Ok(Output::Vllm), "vllm" => Ok(Output::Vllm),
"echo_full" => Ok(Output::EchoFull), "echo_full" => Ok(Output::EchoFull),
...@@ -164,6 +168,7 @@ impl fmt::Display for Output { ...@@ -164,6 +168,7 @@ impl fmt::Display for Output {
Output::LlamaCpp => "llamacpp", Output::LlamaCpp => "llamacpp",
Output::SgLang => "sglang", Output::SgLang => "sglang",
Output::Trtllm => "trtllm",
Output::Vllm => "vllm", Output::Vllm => "vllm",
Output::EchoFull => "echo_full", Output::EchoFull => "echo_full",
...@@ -210,6 +215,7 @@ impl Output { ...@@ -210,6 +215,7 @@ impl Output {
} }
out.push(Output::SgLang.to_string()); out.push(Output::SgLang.to_string());
out.push(Output::Trtllm.to_string());
out.push(Output::Vllm.to_string()); out.push(Output::Vllm.to_string());
#[cfg(feature = "python")] #[cfg(feature = "python")]
......
...@@ -15,6 +15,7 @@ use dynamo_llm::local_model::LocalModel; ...@@ -15,6 +15,7 @@ use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::protocols::Endpoint as EndpointId; use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang; pub mod sglang;
pub mod trtllm;
pub mod vllm; pub mod vllm;
pub async fn start( pub async fn start(
......
...@@ -2,16 +2,17 @@ ...@@ -2,16 +2,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# TODO: # TODO:
# - Add event and metrics publishers
# - Support default dynamo-run out=trtllm launch
# - Support disaggregated serving # - Support disaggregated serving
# - Update examples to use this engine.
# #
# `dynamo-run out=trtllm` runs this script
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params # Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
import argparse import argparse
import asyncio import asyncio
import logging import logging
import sys import sys
import warnings
from typing import Optional from typing import Optional
import uvloop import uvloop
...@@ -20,10 +21,13 @@ import uvloop ...@@ -20,10 +21,13 @@ import uvloop
from tensorrt_llm import SamplingParams from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from trtllm.engine import get_llm_engine
from trtllm.publishers import Publishers
from dynamo.llm import ModelType, register_llm from dynamo.llm import (
ModelType,
get_tensorrtllm_engine,
get_tensorrtllm_publisher,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
# Only used if you run it manually from the command line # Only used if you run it manually from the command line
...@@ -44,7 +48,7 @@ class Config: ...@@ -44,7 +48,7 @@ class Config:
component: str component: str
endpoint: str endpoint: str
model_path: str model_path: str
model_name: Optional[str] model_name: Optional[str] = None
tensor_parallel_size: int tensor_parallel_size: int
kv_block_size: int kv_block_size: int
extra_engine_args: str extra_engine_args: str
...@@ -65,7 +69,9 @@ class RequestHandler: ...@@ -65,7 +69,9 @@ class RequestHandler:
async def generate(self, request): async def generate(self, request):
# Check if there is an error in the publishers error queue # Check if there is an error in the publishers error queue
publishers_error = self.publishers.check_error_queue() publishers_error = (
self.publishers.check_error_queue() if self.publishers else None
)
if publishers_error: if publishers_error:
raise publishers_error raise publishers_error
...@@ -90,7 +96,7 @@ class RequestHandler: ...@@ -90,7 +96,7 @@ class RequestHandler:
# TRTLLM engine needs to start generating tokens first before stats # TRTLLM engine needs to start generating tokens first before stats
# can be retrieved. # can be retrieved.
if self.first_generation and self.publishers: if self.first_generation and self.publishers:
self.publishers.start_publish_threads() self.publishers.start()
self.first_generation = False self.first_generation = False
if res.finished: if res.finished:
...@@ -137,6 +143,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -137,6 +143,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"disable_log_stats": False, "disable_log_stats": False,
} }
if config.extra_engine_args != "": if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args) arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
if config.publish_events_and_metrics: if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events. # 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
...@@ -168,34 +175,33 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -168,34 +175,33 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params._setup(tokenizer) default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None default_sampling_params.stop = None
async with get_llm_engine(engine_args) as engine: async with get_tensorrtllm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm( await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name ModelType.Backend, endpoint, config.model_path, config.model_name
) )
publishers = None
if config.publish_events_and_metrics: if config.publish_events_and_metrics:
# Initialize and pass in the publishers to the request handler to
# publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component( kv_listener = runtime.namespace(config.namespace).component(
config.component config.component
) )
publishers = Publishers( async with get_tensorrtllm_publisher(
component, component,
engine, engine,
kv_listener, kv_listener,
int(endpoint.lease_id()), int(endpoint.lease_id()),
config.kv_block_size, config.kv_block_size,
) ) as publisher:
handler = RequestHandler(
handler = RequestHandler(component, engine, default_sampling_params, publishers) component, engine, default_sampling_params, publisher
)
try: await endpoint.serve_endpoint(handler.generate)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) else:
# after the lease is revoked # No publishers, so just pass in None to the request handler.
handler = RequestHandler(component, engine, default_sampling_params, None)
await endpoint.serve_endpoint(handler.generate) await endpoint.serve_endpoint(handler.generate)
finally:
if publishers:
await publishers.cleanup()
def cmd_line_args(): def cmd_line_args():
...@@ -228,6 +234,12 @@ def cmd_line_args(): ...@@ -228,6 +234,12 @@ def cmd_line_args():
parser.add_argument( parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block." "--kv-block-size", type=int, default=32, help="Size of a KV cache block."
) )
parser.add_argument(
"--context-length",
type=int,
default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
)
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
type=str, type=str,
...@@ -241,6 +253,12 @@ def cmd_line_args(): ...@@ -241,6 +253,12 @@ def cmd_line_args():
) )
args = parser.parse_args() args = parser.parse_args()
if args.context_length is not None:
warnings.warn(
"--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
UserWarning,
)
config = Config() config = Config()
config.model_path = args.model_path config.model_path = args.model_path
if args.model_name: if args.model_name:
......
...@@ -33,3 +33,13 @@ from dynamo._core import KvRouter as KvRouter ...@@ -33,3 +33,13 @@ from dynamo._core import KvRouter as KvRouter
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import register_llm as register_llm from dynamo._core import register_llm as register_llm
try:
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_llm_engine as get_tensorrtllm_engine,
)
from dynamo.llm.tensorrtllm import ( # noqa: F401
get_publisher as get_tensorrtllm_publisher,
)
except ImportError:
pass # TensorRTLLM is not enabled by default
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine import get_llm_engine # noqa: F401
from .publisher import get_publisher # noqa: F401
...@@ -7,6 +7,7 @@ import logging ...@@ -7,6 +7,7 @@ import logging
import threading import threading
import traceback import traceback
import weakref import weakref
from contextlib import asynccontextmanager
from queue import Queue from queue import Queue
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
...@@ -80,7 +81,7 @@ class ManagedThread(threading.Thread): ...@@ -80,7 +81,7 @@ class ManagedThread(threading.Thread):
self._current_future.cancel() self._current_future.cancel()
class Publishers: class Publisher:
""" """
A class to retrieve stats and kv cache events from TRTLLM engine and publish them to the metrics and events publishers. A class to retrieve stats and kv cache events from TRTLLM engine and publish them to the metrics and events publishers.
""" """
...@@ -102,7 +103,6 @@ class Publishers: ...@@ -102,7 +103,6 @@ class Publishers:
self.partial_block_hashes = set() self.partial_block_hashes = set()
self.error_queue: Queue = Queue() self.error_queue: Queue = Queue()
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._setup()
async def _create_metrics_publisher_endpoint(self): async def _create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint") logging.debug("Creating metrics publisher endpoint")
...@@ -111,7 +111,7 @@ class Publishers: ...@@ -111,7 +111,7 @@ class Publishers:
return return
await self.metrics_publisher.create_endpoint(self.component) await self.metrics_publisher.create_endpoint(self.component)
def _setup(self): def initialize(self):
# Setup the metrics publisher # Setup the metrics publisher
self.metrics_publisher = KvMetricsPublisher() self.metrics_publisher = KvMetricsPublisher()
self._init_publish_metrics_thread() self._init_publish_metrics_thread()
...@@ -298,7 +298,7 @@ class Publishers: ...@@ -298,7 +298,7 @@ class Publishers:
self.kv_event_publisher.publish_removed(event_id, block_hashes) self.kv_event_publisher.publish_removed(event_id, block_hashes)
return True return True
def start_publish_threads(self): def start(self):
if ( if (
self.publish_kv_cache_events_thread self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive() and not self.publish_kv_cache_events_thread.is_alive()
...@@ -342,3 +342,16 @@ class Publishers: ...@@ -342,3 +342,16 @@ class Publishers:
self.publish_kv_cache_events_thread.join(timeout=cleanup_timeout) self.publish_kv_cache_events_thread.join(timeout=cleanup_timeout)
if self.publish_kv_cache_events_thread.is_alive(): if self.publish_kv_cache_events_thread.is_alive():
logging.warning("KV cache events thread did not stop within timeout") logging.warning("KV cache events thread did not stop within timeout")
@asynccontextmanager
async def get_publisher(component, engine, kv_listener, worker_id, kv_block_size):
publisher = Publisher(component, engine, kv_listener, worker_id, kv_block_size)
try:
publisher.initialize()
yield publisher
except Exception as e:
logging.error(f"Error in engine context: {e}")
raise
finally:
await publisher.cleanup()
...@@ -141,6 +141,7 @@ addopts = [ ...@@ -141,6 +141,7 @@ addopts = [
"--ignore-glob=*model.py", "--ignore-glob=*model.py",
"--ignore-glob=*_inc.py", "--ignore-glob=*_inc.py",
"--ignore-glob=deploy/cloud/api-store/*", "--ignore-glob=deploy/cloud/api-store/*",
"--ignore-glob=*/llm/tensorrtllm*",
# FIXME: Get relative/generic blob paths to work here # FIXME: Get relative/generic blob paths to work here
] ]
xfail_strict = true xfail_strict = true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment