"vscode:/vscode.git/clone" did not exist on "44bde250656f68864dcd4942d8648454afc928fa"
Commit 4b42b232 authored by Tanmay Verma, committed by GitHub

feat: LLM API example integration (#182)


Co-authored-by: NVShreyas <158103197+NVShreyas@users.noreply.github.com>
parent 03d0f6a2
@@ -31,10 +31,6 @@ RUN apt-get update && \
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"
# 'etcd' is runtime dependency
RUN wget https://github.com/etcd-io/etcd/releases/download/v3.5.18/etcd-v3.5.18-linux-amd64.tar.gz && tar -xzf etcd-v3.5.18-linux-amd64.tar.gz
RUN cp ./etcd-v3.5.18-linux-amd64/etcd* /usr/local/bin/.
# Install OpenAI-compatible frontend and its dependencies from triton server
# repository. These are used to have a consistent interface, schema, and FastAPI
# app between Triton Core and Triton Distributed implementations.
@@ -77,10 +73,11 @@ RUN pip install "git+https://github.com/triton-inference-server/perf_analyzer.gi
ARG FRAMEWORK="STANDARD"
ARG TENSORRTLLM_BACKEND_REPO_TAG=
ARG TENSORRTLLM_BACKEND_REBUILD=
ARG TENSORRTLLM_SKIP_CLONE=
ENV FRAMEWORK=${FRAMEWORK}
RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,target=/tmp/requirements.txt \
--mount=type=bind,source=./container/deps/clone_tensorrtllm.sh,target=/tmp/clone_tensorrtllm.sh \
if [[ "$FRAMEWORK" == "TENSORRTLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt; /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/triton/llm_binding ; fi
if [[ "$FRAMEWORK" == "TENSORRTLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt; if [ ${TENSORRTLLM_SKIP_CLONE} -ne 1 ] ; then /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/triton/llm_binding ; fi ; fi
RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \
@@ -114,6 +111,13 @@ ENV PYTHONUNBUFFERED=1
# Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
# etcd
ENV ETCD_VERSION="v3.5.18"
RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
mkdir -p /usr/local/bin/etcd && \
tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
ENV PATH=/usr/local/bin/etcd/:$PATH
# Enable Git operations in the /workspace directory.
RUN printf "[safe]\n directory=/workspace\n" > /root/.gitconfig
@@ -65,7 +65,11 @@ TENSORRTLLM_BACKEND_REPO_TAG=triton-llm/v0.17.0
# Set this as 1 to rebuild and replace trtllm backend bits in the container.
# This will allow building triton distributed container image with custom
# trt-llm backend repo branch.
TENSORRTLLM_BACKEND_REBUILD=1
TENSORRTLLM_BACKEND_REBUILD=0
# Set this as 1 to skip cloning the trt-llm backend repo. If cloning is skipped, trt-llm
# backend repo tag and rebuild flag will be ignored. Use this option if you are using
# trtllm llmapi worker.
TENSORRTLLM_SKIP_CLONE=0
VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
VLLM_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
@@ -109,6 +113,14 @@ get_options() {
missing_requirement $1
fi
;;
--skip-clone-tensorrtllm)
if [ "$2" ]; then
TENSORRTLLM_SKIP_CLONE=$2
shift
else
missing_requirement $1
fi
;;
--base-image)
if [ "$2" ]; then
BASE_IMAGE=$2
@@ -241,6 +253,7 @@ show_image_options() {
if [[ $FRAMEWORK == "TENSORRTLLM" ]]; then
echo " Tensorrtllm Backend Repo Tag: '${TENSORRTLLM_BACKEND_REPO_TAG}'"
echo " Tensorrtllm Backend Rebuild: '${TENSORRTLLM_BACKEND_REBUILD}'"
echo " Tensorrtllm Skip Clone: '${TENSORRTLLM_SKIP_CLONE}'"
fi
echo " Build Context: '${BUILD_CONTEXT}'"
echo " Build Arguments: '${BUILD_ARGS}'"
@@ -256,6 +269,7 @@ show_help() {
echo " [--framework framework one of ${!FRAMEWORKS[@]}]"
echo " [--tensorrtllm-backend-repo-tag commit or tag]"
echo " [--tensorrtllm-backend-rebuild whether or not to rebuild the backend]"
echo " [--skip-clone-tensorrtllm whether or not to skip cloning the trt-llm backend repo]"
echo " [--build-arg additional build args to pass to docker build]"
echo " [--tag tag for image]"
echo " [--no-cache disable docker build cache]"
@@ -295,6 +309,7 @@ fi
if [[ $FRAMEWORK == "TENSORRTLLM" ]] && [ ! -z ${TENSORRTLLM_BACKEND_REPO_TAG} ]; then
BUILD_ARGS+=" --build-arg TENSORRTLLM_BACKEND_REPO_TAG=${TENSORRTLLM_BACKEND_REPO_TAG} "
BUILD_ARGS+=" --build-arg TENSORRTLLM_BACKEND_REBUILD=${TENSORRTLLM_BACKEND_REBUILD} "
BUILD_ARGS+=" --build-arg TENSORRTLLM_SKIP_CLONE=${TENSORRTLLM_SKIP_CLONE} "
fi
if [ ! -z ${HF_TOKEN} ]; then
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# TensorRT-LLM Integration with Triton Distributed
This example demonstrates how to use Triton Distributed to serve large language models with the tensorrt_llm engine, enabling efficient model serving with both monolithic and disaggregated deployment options.
## Prerequisites
Start the required services (etcd and NATS) using one of the options below; a quick health check follows the options:
Option A: Using [Docker Compose](/runtime/rust/docker-compose.yml) (Recommended)
```bash
docker-compose up -d
```
Option B: Manual Setup
- [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
- example: `nats-server -js --trace`
- [etcd](https://etcd.io) server
- follow instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd-server` locally
- example: `etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379`
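Whichever option you use, you can quickly verify that both services are reachable before continuing. A minimal sketch, assuming the default ports and that `curl` and `nc` are installed:
```bash
# etcd health check (default client port 2379); expect {"health":"true",...}
curl -s http://localhost:2379/health
# NATS client port check (default client port 4222)
nc -zv localhost 4222
```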
## Building the Environment
TODO: Remove the internal references below.
- Build the TRT-LLM wheel using the latest tensorrt_llm main branch
```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
# Start a dev docker container. Don't forget to mount your home directory to /home in the docker run command.
make -C docker jenkins_run LOCAL_USER=1 DOCKER_RUN_ARGS="-v /user/home:/home"
# Build wheel for the GPU architecture you are currently using ("native").
# Passing -f runs a fast build, which speeds up the process but might not work for all GPUs; for full functionality, disable it.
python3 scripts/build_wheel.py --clean --trt_root /usr/local/tensorrt -a native -i -p -ccache
# Copy wheel to your local directory
cp build/tensorrt_llm-*.whl /home
```
- Build the Triton Distributed container
```bash
# Build image
./container/build.sh --base-image gitlab-master.nvidia.com:5005/dl/dgx/tritonserver/tensorrt-llm/amd64 --base-image-tag krish-fix-trtllm-build.23766174
```
Alternatively, you can build against the latest tensorrt_llm pipeline image as shown below:
```bash
# Build image
./container/build.sh --framework TENSORRTLLM --skip-clone-tensorrtllm 1 --base-image urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm-staging/release --base-image-tag main
```
**Note:** If you are using the latest tensorrt_llm image, you do not need to install the TRT-LLM wheel.
## Launching the Environment
```bash
# Run the image interactively from the Triton Distributed root directory.
./container/run.sh --framework TENSORRTLLM -it -v /home/:/home/
# Install the TRT-LLM wheel. No need to do this if you are using the latest tensorrt_llm image.
pip install /home/tensorrt_llm-*.whl
```
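To confirm that tensorrt_llm is importable inside the container (whether from the wheel or from the preinstalled image), a quick sanity check:
```bash
# Should print the installed TensorRT-LLM version without errors
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
```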
## Deployment Options
Note: NATS and ETCD servers should be running and accessible from the container as described in the [Prerequisites](#prerequisites) section.
### 1. Monolithic Deployment
Run the server and client components in separate terminal sessions:
**Server:**
Note: The following commands were tested on machines with 8x H100 GPUs.
#### Option 1.1 Single-Node Single-GPU
```bash
# Launch worker
cd /workspace/examples/python_rs/llm/tensorrt_llm
mpirun --allow-run-as-root -n 1 --oversubscribe python3 -m monolith.worker --engine_args model.json
```
Upon successful launch, the output should look similar to:
```bash
[TensorRT-LLM][INFO] KV cache block reuse is disabled
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 2048
[TensorRT-LLM][INFO] Number of tokens per block: 64.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 26.91 GiB for max tokens in paged KV cache (220480).
[02/14/2025-09:38:53] [TRT-LLM] [I] max_seq_len=131072, max_num_requests=2048, max_num_tokens=8192
[02/14/2025-09:38:53] [TRT-LLM] [I] Engine loaded and ready to serve...
```
`nvidia-smi` can be used to check GPU usage and confirm that the model is loaded on a single GPU.
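For example, a per-GPU memory summary:
```bash
# Show per-GPU memory usage; only one GPU should report significant usage
nvidia-smi --query-gpu=index,memory.used --format=csv
```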
#### Option 1.2 Single-Node Multi-GPU
Update `tensor_parallel_size` in the `model.json` to load the model with the desired number of GPUs.
For this example, we will load the model with 4 GPUs.
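One way to make the edit, assuming the example `model.json` shipped with this example (where `tensor_parallel_size` defaults to `null`):
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm
# Set tensor_parallel_size to 4 in model.json (in-place edit)
sed -i 's/"tensor_parallel_size": null/"tensor_parallel_size": 4/' model.json
```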
```bash
# Launch worker
cd /workspace/examples/python_rs/llm/tensorrt_llm
mpirun --allow-run-as-root -n 1 --oversubscribe python3 -m monolith.worker --engine_args model.json
```
`nvidia-smi` can be used to check GPU usage and confirm that the model is loaded across 4 GPUs.
#### Option 1.3 Multi-Node Multi-GPU
*Work in progress.*
**Client:**
```bash
# Run client
python3 -m common.client \
--prompt "Describe the capital of France" \
--max-tokens 10 \
--temperature 0.5 \
--component tensorrt-llm
```
The output should look similar to:
```
Annotated(data=',', event=None, comment=[], id=None)
Annotated(data=', Paris', event=None, comment=[], id=None)
Annotated(data=', Paris,', event=None, comment=[], id=None)
Annotated(data=', Paris, in', event=None, comment=[], id=None)
Annotated(data=', Paris, in terms', event=None, comment=[], id=None)
Annotated(data=', Paris, in terms of', event=None, comment=[], id=None)
Annotated(data=', Paris, in terms of its', event=None, comment=[], id=None)
Annotated(data=', Paris, in terms of its history', event=None, comment=[], id=None)
Annotated(data=', Paris, in terms of its history,', event=None, comment=[], id=None)
Annotated(data=', Paris, in terms of its history, culture', event=None, comment=[], id=None)
```
### 2. Disaggregated Deployment
#### 2.1 Single-Node Disaggregated Deployment
**Environment**
Use the latest image in which tensorrt_llm supports disaggregated serving with the PyTorch workflow in the LLM API.
Run the container interactively with the following command:
```bash
./container/run.sh --image IMAGE -it
```
**TRTLLM LLMAPI Disaggregated config file**
Define a disaggregated config file similar to the example [single_node_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_config.yaml). The important sections are `model`, `context_servers`, and `generation_servers`.
**Launch the servers**
Launch the context and generation servers.\
WORLD_SIZE is the total number of workers across all the servers described in the disaggregated configuration.\
For example, 2 generation servers with TP=2 are 2 servers but 4 workers/MPI ranks.
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
mpirun --allow-run-as-root --oversubscribe -n WORLD_SIZE python3 -m disaggregated.worker --engine_args model.json -c disaggregated/llmapi_disaggregated_configs/single_node_config.yaml &
```
If using the provided [single_node_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_config.yaml), WORLD_SIZE should be 3, as it has 2 context servers (TP=1) and 1 generation server (TP=1).
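If you prefer to compute WORLD_SIZE from the config rather than counting by hand, here is a minimal sketch (it assumes PyYAML is available in the container):
```bash
python3 - <<'EOF'
# WORLD_SIZE = sum of num_instances * tp_size * pp_size over all server groups
import yaml

with open("disaggregated/llmapi_disaggregated_configs/single_node_config.yaml") as f:
    cfg = yaml.safe_load(f)

world_size = sum(
    s["num_instances"] * s["tp_size"] * s["pp_size"]
    for s in (cfg["context_servers"], cfg["generation_servers"])
)
print(f"WORLD_SIZE={world_size}")
EOF
```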
**Launch the router**
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
python3 -m disaggregated.router -c disaggregated/llmapi_disaggregated_configs/single_node_config.yaml &
```
**Send Requests**
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
python3 -m common.client \
--prompt "Describe the capital of France" \
--max-tokens 10 \
--temperature 0.5 \
--component router
```
For more details on the disaggregated deployment, please refer to the [TRT-LLM example](#TODO).
### 3. Multi-Node Disaggregated Deployment
To run the disaggregated deployment across multiple nodes, we need to launch the servers using MPI, pass the correct NATS and etcd endpoints to each server, and update the LLMAPI disaggregated config file to use the correct endpoints.
1. Allocate nodes
The following command allocates nodes for the job and returns the allocated nodes.
```bash
salloc -A ACCOUNT -N NUM_NODES -p batch -J JOB_NAME -t HH:MM:SS
```
You can use `squeue -u $USER` to check the hostnames of the allocated nodes. Add these hostnames to the TRTLLM LLMAPI disaggregated config file as shown below; a command for expanding the allocated node list follows the example config.
```yaml
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
...
context_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node1:8001"
- "node2:8002"
generation_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node2:8003"
- "node2:8004"
```
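To list the hostnames Slurm actually allocated (for filling in the `urls` above), you can expand the node list from inside the allocation, for example:
```bash
# Expand the allocated node list into one hostname per line
scontrol show hostnames "$SLURM_JOB_NODELIST"
```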
2. Start the NATS and ETCD endpoints
Use the following commands; they require downloading [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) and [etcd](https://etcd.io/docs/v3.5/install/):
```bash
./nats-server -js --trace
./etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379
```
Export the correct NATS and etcd endpoints.
```bash
export NATS_SERVER="nats://node1:4222"
export ETCD_ENDPOINTS="http://node1:2379,http://node2:2379"
```
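Optionally, verify the endpoints are reachable before launching the workers. A sketch, assuming `etcdctl` (shipped in the etcd release tarball) and `nc` are available:
```bash
# etcd cluster health across all configured endpoints
etcdctl --endpoints="$ETCD_ENDPOINTS" endpoint health
# NATS client port on node1
nc -zv node1 4222
```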
3. Launch the workers from node1 or a login node. WORLD_SIZE is determined the same way as in the single-node deployment. Point the worker at the new disaggregated config file via the `-c` flag, as in the command below.
```bash
srun --mpi pmix -N NUM_NODES --ntasks WORLD_SIZE --ntasks-per-node=WORLD_SIZE --no-container-mount-home --overlap --container-image IMAGE --output batch_%x_%j.log --err batch_%x_%j.err --container-mounts PATH_TO_TRITON_DISTRIBUTED:/workspace --container-env=NATS_SERVER,ETCD_ENDPOINTS bash -c 'cd /workspace/examples/python_rs/llm/tensorrt_llm && python3 -m disaggregated.worker --engine_args model.json -c disaggregated/llmapi_disaggregated_configs/multi_node_config.yaml' &
```
Once the workers are launched, you should see output similar to the following in the worker logs.
```
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 18.88 GiB for max tokens in paged KV cache (1800032).
[02/20/2025-07:10:33] [TRT-LLM] [I] max_seq_len=2048, max_num_requests=2048, max_num_tokens=8192
[02/20/2025-07:10:33] [TRT-LLM] [I] Engine loaded and ready to serve...
[02/20/2025-07:10:33] [TRT-LLM] [I] max_seq_len=2048, max_num_requests=2048, max_num_tokens=8192
[TensorRT-LLM][INFO] Number of tokens per block: 32.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 18.88 GiB for max tokens in paged KV cache (1800032).
[02/20/2025-07:10:33] [TRT-LLM] [I] max_seq_len=2048, max_num_requests=2048, max_num_tokens=8192
[02/20/2025-07:10:33] [TRT-LLM] [I] Engine loaded and ready to serve...
```
4. Launch the router from node1 or a login node.
```bash
srun --mpi pmix -N 1 --ntasks 1 --ntasks-per-node=1 --overlap --container-image IMAGE --output batch_router_%x_%j.log --err batch_router_%x_%j.err --container-mounts PATH_TO_TRITON_DISTRIBUTED:/workspace --container-env=NATS_SERVER,ETCD_ENDPOINTS bash -c 'cd /workspace/examples/python_rs/llm/tensorrt_llm && python3 -m disaggregated.router -c disaggregated/llmapi_disaggregated_configs/multi_node_config.yaml' &
```
5. Send requests to the router.
```bash
srun --mpi pmix -N 1 --ntasks 1 --ntasks-per-node=1 --overlap --container-image IMAGE --output batch_client_%x_%j.log --err batch_client_%x_%j.err --container-mounts PATH_TO_TRITON_DISTRIBUTED:/workspace --container-env=NATS_SERVER,ETCD_ENDPOINTS bash -c 'cd /workspace/examples/python_rs/llm/tensorrt_llm && python3 -m common.client --prompt "Describe the capital of France" --max-tokens 10 --temperature 0.5 --component router' &
```
Finally, you should see output similar to the following in the client logs.
```
Annotated(data='and', event=None, comment=[], id=None)
Annotated(data='and its', event=None, comment=[], id=None)
Annotated(data='and its significance', event=None, comment=[], id=None)
Annotated(data='and its significance in', event=None, comment=[], id=None)
Annotated(data='and its significance in the', event=None, comment=[], id=None)
Annotated(data='and its significance in the country', event=None, comment=[], id=None)
Annotated(data="and its significance in the country'", event=None, comment=[], id=None)
Annotated(data="and its significance in the country's", event=None, comment=[], id=None)
Annotated(data="and its significance in the country's history", event=None, comment=[], id=None)
Annotated(data="and its significance in the country's history.", event=None, comment=[], id=None)
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uvloop
from triton_distributed.runtime import DistributedRuntime, triton_worker
from .protocol import Request
@triton_worker()
async def worker(
runtime: DistributedRuntime,
component: str,
prompt: str,
max_tokens: int,
temperature: float,
streaming: bool,
):
"""
Instantiate a `backend` client and call the `generate` endpoint
"""
# create client
client = (
await runtime.namespace("triton-init")
.component(component)
.endpoint("generate")
.client()
)
# list the endpoints
print(client.endpoint_ids())
# issue request
tasks = []
for _ in range(1):
tasks.append(
client.generate(
Request(
prompt=prompt,
sampling_params={
"temperature": temperature,
"max_tokens": max_tokens,
},
streaming=streaming,
).model_dump_json()
)
)
streams = await asyncio.gather(*tasks)
# process response
for stream in streams:
async for resp in stream:
print(resp)
if __name__ == "__main__":
uvloop.install()
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="what is the capital of france?")
parser.add_argument("--max-tokens", type=int, default=10)
parser.add_argument("--temperature", type=float, default=0.5)
parser.add_argument("--streaming", type=bool, default=True)
parser.add_argument(
"--component", type=str, default="router", help="component to send request to"
)
args = parser.parse_args()
asyncio.run(
worker(
args.component,
args.prompt,
args.max_tokens,
args.temperature,
args.streaming,
)
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from typing import Any, Dict, Tuple
# Define the expected keys for each config
# TODO: Add more keys as needed
PYTORCH_CONFIG_KEYS = {
"use_cuda_graph",
"cuda_graph_batch_sizes",
"cuda_graph_max_batch_size",
"cuda_graph_padding_enabled",
"enable_overlap_scheduler",
"kv_cache_dtype",
"torch_compile_enabled",
"torch_compile_fullgraph",
"torch_compile_inductor_enabled",
}
LLM_ENGINE_KEYS = {
"model",
"tokenizer",
"tokenizer_model",
"skip_tokenizer_init",
"trust_remote_code",
"tensor_parallel_size",
"dtype",
"revision",
"tokenizer_revision",
"speculative_model",
"enable_chunked_prefill",
}
def _get_llm_args(args_dict):
# Validation checks
for k, v in args_dict.items():
if (
k not in LLM_ENGINE_KEYS
and k not in PYTORCH_CONFIG_KEYS
and k != "copyright"
):
raise ValueError(f"Unrecognized key in --engine_args file: {k}")
pytorch_config_args = {
k: v for k, v in args_dict.items() if k in PYTORCH_CONFIG_KEYS and v is not None
}
llm_engine_args = {
k: v for k, v in args_dict.items() if k in LLM_ENGINE_KEYS and v is not None
}
if "model" not in llm_engine_args:
raise ValueError("Model name is required in the TRT-LLM engine config.")
return (pytorch_config_args, llm_engine_args)
def _init_engine_args(engine_args_filepath):
"""Initialize engine arguments from config file."""
if not os.path.isfile(engine_args_filepath):
raise ValueError(
f"'{engine_args_filepath}' containing TRT-LLM engine args must be provided in when launching the worker"
)
try:
with open(engine_args_filepath) as file:
trtllm_engine_config = json.load(file)
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse engine config: {e}")
return _get_llm_args(trtllm_engine_config)
def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser")
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
parser.add_argument(
"--llmapi-disaggregated-config",
"-c",
type=str,
help="Path to the llmapi disaggregated config file",
default=None,
)
args = parser.parse_args()
return (args, _init_engine_args(args.engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
from tensorrt_llm.llmapi import DisaggregatedParams
class Request(BaseModel):
prompt: str
sampling_params: dict
streaming: bool = True
class Response(BaseModel):
text: str
class DisaggregatedRequest(Request):
disaggregated_params: dict = {}
class DisaggregatedResponse(Response):
disaggregated_params: DisaggregatedParams = {}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
context_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node1:8001"
- "node1:8002"
generation_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node2:8003"
- "node2:8004"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
context_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 1
pp_size: 1
urls:
- "localhost:8001"
- "localhost:8002"
generation_servers:
num_instances: 1
gpu_fraction: 0.25
tp_size: 1
pp_size: 1
urls:
- "localhost:8002"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import copy
from dataclasses import asdict
import uvloop
from common.protocol import DisaggregatedRequest, DisaggregatedResponse, Response
from tensorrt_llm.llmapi import DisaggregatedParams
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
parse_disagg_config_file,
)
from tensorrt_llm.logger import logger
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
logger.set_level("info")
class Router:
def __init__(self, ctx_client, gen_client):
self.ctx_server_idx = 0
self.gen_server_idx = 0
self.ctx_client = ctx_client
self.gen_client = gen_client
logger.info("INITIALIZED ROUTER")
@triton_endpoint(DisaggregatedRequest, Response)
async def generate(self, request):
gen_req = copy.deepcopy(request)
# Send request to context server
request.disaggregated_params = asdict(
DisaggregatedParams(request_type="context_only")
)
request.sampling_params["max_tokens"] = 1
ctx_resp = [
resp
async for resp in await self.ctx_client.round_robin(
request.model_dump_json()
)
]
if len(ctx_resp) > 1:
raise ValueError(
"Context server returned more than one response. This is currently not supported in disaggregated server."
)
ctx_resp_obj = DisaggregatedResponse.parse_raw(ctx_resp[0].data())
if request.streaming:
# When streaming, the context server returns the first token and the rest of the tokens
            # are returned by the generation server. We return the first token here to ensure
            # low TTFT (time to first token).
# NOTE: this might change in the future if trtllm context server returns raw tokens
yield ctx_resp_obj.text
gen_req.disaggregated_params = ctx_resp_obj.disaggregated_params
gen_req.disaggregated_params.request_type = "generation_only"
async for response in await self.gen_client.round_robin(
gen_req.model_dump_json()
):
yield response.data()
@triton_worker()
async def worker(runtime: DistributedRuntime, server_configs: list[CtxGenServerConfig]):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("triton-init").component("router")
await component.create_service()
ctx_client = (
await runtime.namespace("triton-init")
.component("tensorrt-llm-ctx")
.endpoint("generate")
.client()
)
gen_client = (
await runtime.namespace("triton-init")
.component("tensorrt-llm-gen")
.endpoint("generate")
.client()
)
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(Router(ctx_client, gen_client).generate)
if __name__ == "__main__":
uvloop.install()
parser = argparse.ArgumentParser()
parser.add_argument(
"--llmapi-disaggregated-config",
"-c",
type=str,
default="disaggregated/llmapi_disaggregated_configs/single_node_config.yaml",
help="Path to the llmapi disaggregated config file",
)
args = parser.parse_args()
disagg_config = parse_disagg_config_file(args.llmapi_disaggregated_config)
server_configs = disagg_config.server_configs
asyncio.run(worker(server_configs))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import threading
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Tuple
import uvloop
from common.parser import parse_tensorrt_llm_args
from common.protocol import DisaggregatedRequest, DisaggregatedResponse
from mpi4py.futures import MPICommExecutor
from mpi4py.MPI import COMM_WORLD
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._utils import set_mpi_comm
from tensorrt_llm.llmapi import DisaggregatedParams, KvCacheConfig, MpiCommSession
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
DisaggServerConfig,
parse_disagg_config_file,
split_world_comm,
)
from tensorrt_llm.logger import logger
from triton_distributed.runtime import DistributedRuntime, triton_worker
logger.set_level("info")
class TensorrtLLMEngine:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
engine_args: Tuple[Dict[str, Any], Dict[str, Any]],
disagg_config: DisaggServerConfig,
instance_idx: int,
sub_comm,
):
self.pytorch_config_args, self.llm_engine_args = engine_args
self.disagg_config = disagg_config
self.instance_idx = instance_idx
self.server_config: CtxGenServerConfig = disagg_config.server_configs[
instance_idx
]
self.mpi_session = MpiCommSession(sub_comm, n_workers=sub_comm.Get_size())
self._init_engine()
def _init_engine(self):
logger.info("Initializing engine")
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine: Optional[Any] = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
)
self._event_thread.start()
with self._llm_engine_start_cv:
while self._llm_engine is None:
self._llm_engine_start_cv.wait()
        # The 'threading.Thread()' will not raise the exception here should the engine
        # fail to start, so the exception is passed back via the engine variable.
if isinstance(self._llm_engine, Exception):
e = self._llm_engine
logger.error(f"Failed to start engine: {e}")
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
raise e
async def _run_llm_engine(self):
# Counter to keep track of ongoing request counts.
self._ongoing_request_count = 0
@asynccontextmanager
async def async_llm_wrapper():
# Create LLM in a thread to avoid blocking
loop = asyncio.get_running_loop()
try:
pytorch_config = PyTorchConfig(**self.pytorch_config_args)
# TODO: maybe add build config
llm = await loop.run_in_executor(
None,
lambda: LLM(
**self.llm_engine_args,
tensor_parallel_size=self.server_config.other_args["tp_size"],
pipeline_parallel_size=self.server_config.other_args["pp_size"],
gpus_per_node=None,
trust_remote_code=True,
_mpi_session=self.mpi_session,
kv_cache_config=KvCacheConfig(
free_gpu_memory_fraction=self.server_config.other_args[
"gpu_fraction"
]
),
pytorch_backend_config=pytorch_config,
backend="pytorch",
),
)
yield llm
finally:
if "llm" in locals():
# Run shutdown in a thread to avoid blocking
await loop.run_in_executor(None, llm.shutdown)
try:
async with async_llm_wrapper() as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
# Signal the engine is started and make it visible to other threads.
with self._llm_engine_start_cv:
self._llm_engine = engine
self._llm_engine_start_cv.notify_all()
logger.info("Engine loaded and ready to serve...")
# Wait for the engine shutdown signal.
await self._llm_engine_shutdown_event.wait()
# Wait for the ongoing requests to complete.
while self._ongoing_request_count > 0:
logger.info(
"Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(1)
# Cancel all tasks in the event loop.
for task in asyncio.all_tasks(loop=self._event_loop):
if task is not asyncio.current_task():
task.cancel()
except Exception as e:
# Signal and pass the exception back via the engine variable if the engine
# failed to start. If the engine has started, re-raise the exception.
with self._llm_engine_start_cv:
if self._llm_engine is None:
self._llm_engine = e
self._llm_engine_start_cv.notify_all()
return
raise e
self._llm_engine = None
logger.info("Shutdown complete")
async def generate(self, request):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
self._ongoing_request_count += 1
logger.debug(f"Received request: {request}")
request = DisaggregatedRequest.parse_raw(request)
sampling_params = SamplingParams(**request.sampling_params)
disaggregated_params = DisaggregatedParams(**request.disaggregated_params)
# Opaque state is described as an additional state needing to be exchanged
# between context and gen instances
if disaggregated_params.opaque_state is not None:
disaggregated_params.opaque_state = (
disaggregated_params.opaque_state.encode("utf-8")
.decode("unicode_escape")
.encode("latin1")
)
async for response in self._llm_engine.generate_async(
request.prompt,
sampling_params,
streaming=request.streaming,
disaggregated_params=disaggregated_params,
):
logger.debug(f"Generated response: {response}")
if self.server_config.type == "ctx":
yield DisaggregatedResponse(
text=response.outputs[0].text,
disaggregated_params=response.outputs[0].disaggregated_params,
).model_dump_json()
else:
yield response.outputs[0].text
self._ongoing_request_count -= 1
@triton_worker()
async def worker(
runtime: DistributedRuntime,
engine_args: Tuple[Dict[str, Any], Dict[str, Any]],
disagg_config: DisaggServerConfig,
instance_idx: int,
sub_comm,
):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
server_type = disagg_config.server_configs[instance_idx].type
logger.info(f"Starting {server_type} server")
component = runtime.namespace("triton-init").component(
f"tensorrt-llm-{server_type}"
)
await component.create_service()
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(
TensorrtLLMEngine(engine_args, disagg_config, instance_idx, sub_comm).generate
)
if __name__ == "__main__":
uvloop.install()
args, engine_args = parse_tensorrt_llm_args()
if args.llmapi_disaggregated_config is None or not os.path.exists(
args.llmapi_disaggregated_config
):
raise ValueError(
"llmapi_disaggregated_config file does not exist or not provided"
)
disagg_config: DisaggServerConfig = parse_disagg_config_file(
args.llmapi_disaggregated_config
)
logger.info(f"Parsed disaggregated config: {disagg_config}")
is_leader, instance_idx, sub_comm = split_world_comm(disagg_config.server_configs)
os.environ["TRTLLM_USE_MPI_KVCACHE"] = "1"
set_mpi_comm(sub_comm)
logger.info(f"is_leader: {is_leader}, instance_idx: {instance_idx}")
if is_leader:
asyncio.run(worker(engine_args, disagg_config, instance_idx, sub_comm))
else:
with MPICommExecutor(sub_comm) as executor:
if not is_leader and executor is not None:
raise RuntimeError(f"rank{COMM_WORLD} should not have executor")
{
"copyright": [
"SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.",
"SPDX-License-Identifier: Apache-2.0",
"Licensed under the Apache License, Version 2.0 (the \"License\");",
"you may not use this file except in compliance with the License.",
"You may obtain a copy of the License at",
"http://www.apache.org/licenses/LICENSE-2.0",
"Unless required by applicable law or agreed to in writing, software",
"distributed under the License is distributed on an \"AS IS\" BASIS,",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"See the License for the specific language governing permissions and",
"limitations under the License."
],
"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"tokenizer": null,
"tokenizer_model": null,
"skip_tokenizer_init": null,
"trust_remote_code": null,
"tensor_parallel_size": null,
"dtype": null,
"revision": null,
"tokenizer_revision": null,
"speculative_model": null,
"enable_chunked_prefill": null,
"use_cuda_graph": null,
"cuda_graph_batch_sizes": null,
"cuda_graph_max_batch_size": null,
"cuda_graph_padding_enabled": null,
"enable_overlap_scheduler": null,
"kv_cache_dtype": null,
"torch_compile_enabled": null,
"torch_compile_fullgraph": null,
"torch_compile_inductor_enabled": null
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Tuple
import uvloop
from common.parser import parse_tensorrt_llm_args
from common.protocol import Request, Response
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.logger import logger
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
logger.set_level("info")
class TensorrtLLMEngine:
"""
Request handler for the generate endpoint
"""
def __init__(self, engine_args: Tuple[Dict[str, Any], Dict[str, Any]]):
self.pytorch_config_args, self.llm_engine_args = engine_args
self._init_engine()
def _init_engine(self):
logger.info("Initializing engine")
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine: Optional[Any] = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
)
self._event_thread.start()
with self._llm_engine_start_cv:
while self._llm_engine is None:
self._llm_engine_start_cv.wait()
        # The 'threading.Thread()' will not raise the exception here should the engine
        # fail to start, so the exception is passed back via the engine variable.
if isinstance(self._llm_engine, Exception):
e = self._llm_engine
logger.error(f"Failed to start engine: {e}")
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
raise e
async def _run_llm_engine(self):
# Counter to keep track of ongoing request counts.
self._ongoing_request_count = 0
@asynccontextmanager
async def async_llm_wrapper():
# Create LLM in a thread to avoid blocking
loop = asyncio.get_running_loop()
try:
pytorch_config = PyTorchConfig(**self.pytorch_config_args)
llm = await loop.run_in_executor(
None,
lambda: LLM(
**self.llm_engine_args, pytorch_backend_config=pytorch_config
),
)
yield llm
finally:
if "llm" in locals():
# Run shutdown in a thread to avoid blocking
await loop.run_in_executor(None, llm.shutdown)
try:
async with async_llm_wrapper() as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
# Signal the engine is started and make it visible to other threads.
with self._llm_engine_start_cv:
self._llm_engine = engine
self._llm_engine_start_cv.notify_all()
logger.info("Engine loaded and ready to serve...")
# Wait for the engine shutdown signal.
await self._llm_engine_shutdown_event.wait()
# Wait for the ongoing requests to complete.
while self._ongoing_request_count > 0:
logger.info(
"Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(1)
# Cancel all tasks in the event loop.
for task in asyncio.all_tasks(loop=self._event_loop):
if task is not asyncio.current_task():
task.cancel()
except Exception as e:
# Signal and pass the exception back via the engine variable if the engine
# failed to start. If the engine has started, re-raise the exception.
with self._llm_engine_start_cv:
if self._llm_engine is None:
self._llm_engine = e
self._llm_engine_start_cv.notify_all()
return
raise e
self._llm_engine = None
logger.info("Shutdown complete")
@triton_endpoint(Request, Response)
async def generate(self, request):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
self._ongoing_request_count += 1
logger.debug(f"Received request: {request}")
sampling_params = SamplingParams(**request.sampling_params)
async for response in self._llm_engine.generate_async(
request.prompt, sampling_params, streaming=request.streaming
):
logger.debug(f"Generated response: {response}")
yield response.outputs[0].text
self._ongoing_request_count -= 1
@triton_worker()
async def worker(
runtime: DistributedRuntime, engine_args: Tuple[Dict[str, Any], Dict[str, Any]]
):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("triton-init").component("tensorrt-llm")
await component.create_service()
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(TensorrtLLMEngine(engine_args).generate)
if __name__ == "__main__":
uvloop.install()
_, engine_args = parse_tensorrt_llm_args()
asyncio.run(worker(engine_args))
@@ -104,8 +104,9 @@ indent-width = 4
# disable_error_code = []
# --explicit-package-bases: WAR errors about duplicate module names used
# throughout project such as launch_workers.py
# explicit_package_bases = true
# throughout the llm examples. For example, the common module in
# tensorrt_llm and vllm are both named common.
explicit_package_bases = true
# --ignore-missing-imports: WAR too many errors when developing outside
# of container environment with PYTHONPATH set and packages installed.