Commit b92834c8 authored by Neelay Shah, committed by GitHub

chore: removing outdated examples (#202)

parent fd79234f
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# TensorRT-LLM Integration with Dynamo
This example demonstrates how to use Dynamo to serve large language models with the tensorrt_llm engine, enabling efficient model serving with both monolithic and disaggregated deployment options.
## Prerequisites
Start required services (etcd and NATS):
Option A: Using [Docker Compose](/runtime/rust/docker-compose.yml) (Recommended; a minimal compose sketch is shown after Option B)
```bash
docker-compose up -d
```
Option B: Manual Setup
- [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
- example: `nats-server -js --trace`
- [etcd](https://etcd.io) server
- follow instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd-server` locally
- example: `etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379`
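If you want to use Option A but do not have the repository's compose file at hand, the sketch below shows what a minimal equivalent `docker-compose.yml` typically contains. The image names and settings here are assumptions for illustration; the file linked in Option A is the source of truth.
```yaml
# Minimal sketch of a compose file that starts NATS (with JetStream) and etcd.
# Images and settings are assumed; prefer the repository's docker-compose.yml.
services:
  nats-server:
    image: nats
    command: ["-js", "--trace"]          # enable JetStream, matching Option B
    ports:
      - "4222:4222"
  etcd-server:
    image: bitnami/etcd
    environment:
      - ALLOW_NONE_AUTHENTICATION=yes    # development only: no authentication
    ports:
      - "2379:2379"
```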
## Building the Environment
TODO: Remove the internal references below.
### Build the Dynamo container with latest TRT-LLM
#### Step 1: Build the TRT-LLM wheel using the latest tensorrt_llm main
```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
# Start a dev docker container. Don't forget to mount your home directory to /home in the docker run command.
make -C docker jenkins_run LOCAL_USER=1 DOCKER_RUN_ARGS="-v /user/home:/home"
# Build the wheel for the GPU architecture you are currently using ("native").
# A fast build (-f) can speed up the build process, but it might not work for all GPUs; disable it for full functionality.
python3 scripts/build_wheel.py --clean --trt_root /usr/local/tensorrt -a native -i -p -ccache
# Copy the wheel to your local directory
cp build/tensorrt_llm-*.whl /home
```
#### Step 2: Copy the TRT-LLM wheel to the dynamo repository
```bash
cp /home/tensorrt_llm-*.whl /<path-to-repo>/dynamo/trtllm_wheel/
```
#### Step 3: Build the container
```bash
# Build image
./container/build.sh --framework TENSORRTLLM --tensorrtllm-pip-wheel-path trtllm_wheel
```
We need to copy the TRT-LLM wheel into the repository and point the build script to its path within
the repository so that it can be picked up by the Docker build context.
## Launching the Environment
```bash
# Run the image interactively from within the Dynamo root directory.
./container/run.sh --framework TENSORRTLLM -it
```
## Deployment Options
Note: NATS and ETCD servers should be running and accessible from the container as described in the [Prerequisites](#prerequisites) section.
### Monolithic Deployment
#### 1. HTTP Server
Run the HTTP server (with debug-level logging):
```bash
DYN_LOG=DEBUG http &
```
By default, the server runs on port 8080.
Add the model to the server:
```bash
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynamo.tensorrt-llm.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynamo.tensorrt-llm.completions
```
#### 2. Workers
Note: The following commands were tested on machines with 8x H100 GPUs.
##### Option 2.1 Single-Node Single-GPU
```bash
# Launch worker
cd /workspace/examples/python_rs/llm/trtllm
mpirun --allow-run-as-root -n 1 --oversubscribe python3 -m monolith.launch --engine_args llm_api_config.yaml 1>agg_worker.log 2>&1 &
```
Upon successful launch, the output should look similar to:
```bash
[TensorRT-LLM][INFO] KV cache block reuse is disabled
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 2048
[TensorRT-LLM][INFO] Number of tokens per block: 64.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 26.91 GiB for max tokens in paged KV cache (220480).
[02/14/2025-09:38:53] [TRT-LLM] [I] max_seq_len=131072, max_num_requests=2048, max_num_tokens=8192
[02/14/2025-09:38:53] [TRT-LLM] [I] Engine loaded and ready to serve...
```
`nvidia-smi` can be used to check GPU usage and verify that the model is loaded on a single GPU.
##### Option 2.2 Single-Node Multi-GPU
Update `tensor_parallel_size` in `llm_api_config.yaml` to load the model across the desired number of GPUs.
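For reference, a minimal sketch of the relevant part of `llm_api_config.yaml` is shown below. Only `model_name` and `tensor_parallel_size` are taken from this example's configuration handling; the other contents of the real file may differ.
```yaml
# Hypothetical excerpt of llm_api_config.yaml; the actual file may contain
# additional keys. tensor_parallel_size controls how many GPUs the model is
# sharded across.
model_name: TinyLlama/TinyLlama-1.1B-Chat-v1.0
tensor_parallel_size: 4
```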
`nvidia-smi` can be used to confirm that the model is loaded across the requested number of GPUs (for example, 4 GPUs when `tensor_parallel_size: 4`).
##### Option 2.3 Multi-Node Multi-GPU
TODO: Add multi-node multi-GPU example
#### 3. Client
```bash
# Chat Completion
curl localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
]
}'
```
The output should look similar to:
```json
{
"id": "ab013077-8fb2-433e-bd7d-88133fccd497",
"choices": [
{
"message": {
"role": "assistant",
"content": "The capital of France is Paris."
},
"index": 0,
"finish_reason": "stop"
}
],
"created": 1740617803,
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "chat.completion",
"usage": null,
"system_fingerprint": null
}
```
```bash
# Completion
curl localhost:8080/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"prompt": "The capital of France is",
"max_tokens": 1,
"temperature": 0
}'
```
Output:
```json
{
"id":"cmpl-e0d75aca1bd540399809c9b609eaf010",
"choices":[
{
"text":"Paris",
"index":0,
"finish_reason":"length"
}
],
"created":1741024639,
"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object":"text_completion",
"usage":null
}
```
### Disaggregated Deployment
**Environment**
This is the latest image with a tensorrt_llm build that supports distributed serving using the PyTorch workflow in the LLM API.
Run the container interactively with the following command:
```bash
./container/run.sh --image IMAGE -it
```
#### 1. HTTP Server
Run the HTTP server (with debug-level logging):
```bash
DYN_LOG=DEBUG http &
```
By default, the server runs on port 8080.
Add the model to the server:
```bash
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynamo.router.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynamo.router.completions
```
#### 2. Workers
##### Option 2.1 Single-Node Disaggregated Deployment
**TRTLLM LLMAPI Disaggregated config file**
Define a disaggregated config file similar to the example [single_node_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_config.yaml). The important sections are `model`, `context_servers` and `generation_servers`.
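For orientation, a minimal single-node sketch of such a file, modeled on the multi-node example shown later in this guide, might look like the following. The values are illustrative; use the provided `single_node_config.yaml` as the source of truth.
```yaml
# Illustrative single-node sketch: 1 context server (TP=1) and 1 generation
# server (TP=1), which corresponds to WORLD_SIZE=2 in the launch step below.
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
context_servers:
  num_instances: 1
  tp_size: 1
  pp_size: 1
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tp_size: 1
  pp_size: 1
  urls:
    - "localhost:8002"
```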
1. **Launch the servers**
Launch the context and generation servers.\
WORLD_SIZE is the total number of workers across all servers described in the disaggregated configuration.\
For example, 2 TP=2 generation servers are 2 servers but 4 workers/MPI executors.
```bash
cd /workspace/examples/python_rs/llm/trtllm/
mpirun --allow-run-as-root --oversubscribe -n WORLD_SIZE python3 -m disaggregated.worker --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/single_node_config.yaml 1>disagg_workers.log 2>&1 &
```
If using the provided [single_node_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_config.yaml), WORLD_SIZE should be 2, as it has 1 context server (TP=1) and 1 generation server (TP=1).
2. **Launch the router**
```bash
cd /workspace/examples/python_rs/llm/trtllm/
python3 -m disaggregated.router 1>router.log 2>&1 &
```
Note: For KV cache aware routing, please refer to the [KV Aware Routing](./docs/kv_aware_routing.md) section.
3. **Send Requests**
Follow the instructions in the [Monolithic Deployment](#3-client) section to send requests to the router.
For more details on the disaggregated deployment, please refer to the [TRT-LLM example](#TODO).
### Multi-Node Disaggregated Deployment
To run the disaggregated deployment across multiple nodes, we need to launch the servers using MPI, pass the correct NATS and etcd endpoints to each server, and update the LLMAPI disaggregated config file to use the correct endpoints.
1. Allocate nodes
The following command allocates nodes for the job and returns the allocated nodes.
```bash
salloc -A ACCOUNT -N NUM_NODES -p batch -J JOB_NAME -t HH:MM:SS
```
You can use `squeue -u $USER` to check the hostnames of the allocated nodes. These hostnames (with ports) should be added to the `urls` entries of the TRTLLM LLMAPI disaggregated config file, as shown below.
```yaml
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
...
context_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node1:8001"
- "node2:8002"
generation_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node2:8003"
- "node2:8004"
```
2. Start the NATS and ETCD endpoints
Use the following commands. They require downloading [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) and [etcd](https://etcd.io/docs/v3.5/install/):
```bash
./nats-server -js --trace
./etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379
```
Export the correct NATS and etcd endpoints.
```bash
export NATS_SERVER="nats://node1:4222"
export ETCD_ENDPOINTS="http://node1:2379,http://node2:2379"
```
3. Launch the workers from node1 or the login node. WORLD_SIZE is computed the same way as in the single-node deployment.
```bash
srun --mpi pmix -N NUM_NODES --ntasks WORLD_SIZE --ntasks-per-node=WORLD_SIZE --no-container-mount-home --overlap --container-image IMAGE --output batch_%x_%j.log --err batch_%x_%j.err --container-mounts PATH_TO_DYNAMO:/workspace --container-env=NATS_SERVER,ETCD_ENDPOINTS bash -c 'cd /workspace/examples/python_rs/llm/trtllm && python3 -m disaggregated.worker --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/multi_node_config.yaml' &
```
Once the workers are launched, you should see output similar to the following in the worker logs.
```
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 18.88 GiB for max tokens in paged KV cache (1800032).
[02/20/2025-07:10:33] [TRT-LLM] [I] max_seq_len=2048, max_num_requests=2048, max_num_tokens=8192
[02/20/2025-07:10:33] [TRT-LLM] [I] Engine loaded and ready to serve...
[02/20/2025-07:10:33] [TRT-LLM] [I] max_seq_len=2048, max_num_requests=2048, max_num_tokens=8192
[TensorRT-LLM][INFO] Number of tokens per block: 32.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 18.88 GiB for max tokens in paged KV cache (1800032).
[02/20/2025-07:10:33] [TRT-LLM] [I] max_seq_len=2048, max_num_requests=2048, max_num_tokens=8192
[02/20/2025-07:10:33] [TRT-LLM] [I] Engine loaded and ready to serve...
```
4. Launch the router from node1 or the login node.
```bash
srun --mpi pmix -N 1 --ntasks 1 --ntasks-per-node=1 --overlap --container-image IMAGE --output batch_router_%x_%j.log --err batch_router_%x_%j.err --container-mounts PATH_TO_DYNAMO:/workspace --container-env=NATS_SERVER,ETCD_ENDPOINTS bash -c 'cd /workspace/examples/python_rs/llm/trtllm && python3 -m disaggregated.router' &
```
5. Send requests to the router.
The router connects to the OpenAI-compatible server. You can send requests to the router using the standard OpenAI-compatible format, as shown in the previous sections.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from queue import Queue
from typing import Any, Optional
from common.parser import LLMAPIConfig
from common.processor import ChatProcessor, CompletionsProcessor
from common.utils import ManagedThread
from tensorrt_llm._torch import LLM
from tensorrt_llm.logger import logger
from transformers import AutoTokenizer
from dynamo.llm import KvMetricsPublisher
from .kv_cache_event_publisher import KVCacheEventPublisher
class ChatProcessorMixin:
def __init__(self, engine_config: LLMAPIConfig):
self._engine_config = engine_config
logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")
# model name for chat processor
self._model_name = self._engine_config.model_name
logger.info(f"Set model name: {self._model_name}")
# model for LLMAPI input
self._model = self._model_name
if self._engine_config.model_path:
self._model = self._engine_config.model_path
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_path
)
logger.info(f"Using model from path: {self._engine_config.model_path}")
else:
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_name
)
if self._engine_config.extra_args.get("tokenizer", None):
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.extra_args.get("tokenizer", None)
)
self.chat_processor = ChatProcessor(self._model_name, self._tokenizer)
self.completions_processor = CompletionsProcessor(self._model_name)
@dataclass
class TensorrtLLMEngineConfig:
namespace_str: str = "dynamo"
component_str: str = "tensorrt-llm"
engine_config: LLMAPIConfig = None
worker_id: Optional[str] = None
kv_metrics_publisher: Optional[KvMetricsPublisher] = None
publish_stats: bool = False
publish_kv_cache_events: bool = False
class BaseTensorrtLLMEngine(ChatProcessorMixin):
def __init__(
self,
trt_llm_engine_config: TensorrtLLMEngineConfig,
):
super().__init__(trt_llm_engine_config.engine_config)
self._namespace_str = trt_llm_engine_config.namespace_str
self._component_str = trt_llm_engine_config.component_str
self._worker_id = trt_llm_engine_config.worker_id
self._kv_metrics_publisher = trt_llm_engine_config.kv_metrics_publisher
self._publish_stats = trt_llm_engine_config.publish_stats
self._publish_kv_cache_events = trt_llm_engine_config.publish_kv_cache_events
self._error_queue: Optional[Queue] = None
self._init_engine()
def _init_engine(self):
logger.info("Initializing engine")
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine: Optional[Any] = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
)
self.publish_kv_cache_events_thread = None
self.publish_stats_thread = None
self._event_thread.start()
with self._llm_engine_start_cv:
while self._llm_engine is None:
self._llm_engine_start_cv.wait()
# The 'threading.Thread()' will not raise the exception here should the engine
# fail to start, so the exception is passed back via the engine variable.
if isinstance(self._llm_engine, Exception):
e = self._llm_engine
logger.error(f"Failed to start engine: {e}")
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
raise e
self._error_queue = Queue()
try:
if self._publish_stats:
self._init_publish_metrics_thread()
if self._publish_kv_cache_events:
self._init_publish_kv_cache_events_thread()
except Exception as e:
logger.error(f"Failed to initialize publish metrics threads: {e}")
raise e
def _init_publish_metrics_thread(self):
# Need to publish stats once so that the worker can be selected.
# Publishing some dummy values...
request_active_slots = 0
request_total_slots = 4
kv_active_block = 0
kv_total_blocks = 4
num_requests_waiting = 0
gpu_cache_usage_perc = 0.0
gpu_prefix_cache_hit_rate = 0.0
if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!")
return
self._kv_metrics_publisher.publish(
request_active_slots,
request_total_slots,
kv_active_block,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
)
# Prepare threads for publishing stats but don't start them yet.
# TRTLLM needs to start generating tokens first before stats
# can be retrieved.
self.publish_stats_thread = ManagedThread(
self.publish_stats_task,
error_queue=self._error_queue,
name="publish_stats_thread",
)
def _init_publish_kv_cache_events_thread(self):
if self._worker_id is None:
logger.error("Worker ID not initialized!")
return
# TODO: Use python bindings to publish kv cache events once they
# are available.
lib_path = "/opt/dynamo/bindings/lib/libdynamo_llm_capi.so"
self._kv_cache_events_publisher = KVCacheEventPublisher(
self._namespace_str, self._component_str, int(self._worker_id), lib_path
)
# Prepare threads for publishing kv cache events but don't start them yet.
# TRTLLM needs to start generating tokens first before kv cache events
# can be retrieved.
self.publish_kv_cache_events_thread = ManagedThread(
self.publish_kv_cache_events_task,
error_queue=self._error_queue,
name="publish_kv_cache_events_thread",
)
async def publish_stats_task(self):
"""
Publish stats to the metrics publisher.
"""
if self._llm_engine is None:
logger.error("LLM engine not initialized!")
return
stats = self._llm_engine.get_stats_async(timeout=5)
async for stat in stats:
request_active_slots = stat["numActiveRequests"]
request_total_slots = stat["maxNumActiveRequests"]
kv_active_block = stat["kvCacheStats"]["usedNumBlocks"]
kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!")
return False
# TODO: Remove this once we have the actual values.
# Adding dummy values for now so it doesn't break the metrics.
num_requests_waiting = 0
gpu_cache_usage_perc = 0.0
gpu_prefix_cache_hit_rate = 0.0
self._kv_metrics_publisher.publish(
request_active_slots,
request_total_slots,
kv_active_block,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
)
logger.debug(
f"Published stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}"
)
return True
async def publish_kv_cache_events_task(self):
"""
Publish kv cache events to the events publisher.
"""
if self._llm_engine is None:
logger.error("LLM engine not initialized!")
return
events = self._llm_engine.get_kv_cache_events_async(timeout=5)
async for event_list in events:
for event in event_list:
logger.debug(f"Received event from llmapi: {event}")
id = event["event_id"]
data = event["data"]
if data["type"] == "stored":
parent_hash = data["parent_hash"]
token_ids = []
block_hashes = []
for block in data["blocks"]:
block_hash = block["block_hash"]
block_hashes.append(block_hash)
for token in block["tokens"]:
# TODO: How to handle token_extra_id?
token_ids.append(token["token_id"])
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
# Publish the stored event
self._kv_cache_events_publisher.stored_event(
id, parent_hash, block_hashes, token_ids, lora_id
)
logger.debug(
f"Published stored event: {id}, parent_hash: {parent_hash}, block_hashes: {block_hashes}, token_ids: {token_ids}"
)
elif data["type"] == "removed":
# Publish the removed event
block_hashes = []
for block_hash in data["block_hashes"]:
block_hashes.append(block_hash)
self._kv_cache_events_publisher.removed_event(id, block_hashes)
logger.debug(
f"Published removed event: {id}, block_hashes: {block_hashes}"
)
return True
async def _run_llm_engine(self):
# Counter to keep track of ongoing request counts.
self._ongoing_request_count = 0
@asynccontextmanager
async def async_llm_wrapper():
# Create LLM in a thread to avoid blocking
loop = asyncio.get_running_loop()
try:
llm = await loop.run_in_executor(
None,
lambda: LLM(model=self._model, **self._engine_config.to_dict()),
)
yield llm
finally:
if "llm" in locals():
# Run shutdown in a thread to avoid blocking
await loop.run_in_executor(None, llm.shutdown)
try:
async with async_llm_wrapper() as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
# Signal the engine is started and make it visible to other threads.
with self._llm_engine_start_cv:
self._llm_engine = engine
self._llm_engine_start_cv.notify_all()
logger.info("Engine loaded and ready to serve...")
# Wait for the engine shutdown signal.
await self._llm_engine_shutdown_event.wait()
# Stop the publishing threads
if self.publish_stats_thread and self.publish_stats_thread.is_alive():
self.publish_stats_thread.stop()
self.publish_stats_thread.join()
if (
self.publish_kv_cache_events_thread
and self.publish_kv_cache_events_thread.is_alive()
):
self.publish_kv_cache_events_thread.stop()
self.publish_kv_cache_events_thread.join()
# Wait for the ongoing requests to complete.
while self._ongoing_request_count > 0:
logger.info(
"Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(1)
# Cancel all tasks in the event loop.
for task in asyncio.all_tasks(loop=self._event_loop):
if task is not asyncio.current_task():
task.cancel()
except Exception as e:
# Signal and pass the exception back via the engine variable if the engine
# failed to start. If the engine has started, re-raise the exception.
with self._llm_engine_start_cv:
if self._llm_engine is None:
self._llm_engine = e
self._llm_engine_start_cv.notify_all()
return
raise e
self._llm_engine = None
logger.info("Shutdown complete")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
from .protocol import Request
@dynamo_worker()
async def worker(
runtime: DistributedRuntime,
component: str,
prompt: str,
max_tokens: int,
temperature: float,
streaming: bool,
):
"""
Instantiate a `backend` client and call the `generate` endpoint
"""
# create client
client = (
await runtime.namespace("dynamo")
.component(component)
.endpoint("generate")
.client()
)
# list the endpoints
print(client.endpoint_ids())
# issue request
tasks = []
for _ in range(1):
tasks.append(
client.generate(
Request(
prompt=prompt,
sampling_params={
"temperature": temperature,
"max_tokens": max_tokens,
},
streaming=streaming,
).model_dump_json()
)
)
streams = await asyncio.gather(*tasks)
# process response
for stream in streams:
async for resp in stream:
print(resp)
if __name__ == "__main__":
uvloop.install()
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="what is the capital of france?")
parser.add_argument("--max-tokens", type=int, default=10)
parser.add_argument("--temperature", type=float, default=0.5)
parser.add_argument("--streaming", type=bool, default=True)
parser.add_argument(
"--component", type=str, default="router", help="component to send request to"
)
args = parser.parse_args()
asyncio.run(
worker(
args.component,
args.prompt,
args.max_tokens,
args.temperature,
args.streaming,
)
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Dict, List, TypedDict, Union
from common.protocol import DisaggChatCompletionStreamResponse
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatMessage,
DeltaMessage,
FunctionCall,
ToolCall,
UsageInfo,
)
from transformers import AutoTokenizer
class ConversationMessage(TypedDict):
role: str
content: str
def parse_chat_message_content(
message: ChatCompletionMessageParam,
) -> Union[ConversationMessage, List[ConversationMessage], List[None]]:
role = message["role"]
content = message.get("content")
if content is None:
return []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)]
texts: List[str] = []
for part in content:
part_type = part["type"]
if part_type == "text":
text = part["text"] # type: ignore
texts.append(text)
else:
raise NotImplementedError(f"{part_type} is not supported")
text_prompt = "\n".join(texts)
return [ConversationMessage(role=role, content=text_prompt)]
class ChatProcessor:
def __init__(
self, model: str, tokenizer: AutoTokenizer, request: ChatCompletionRequest
):
self.model = model
self.tokenizer = tokenizer
self.request = request
self.num_choices = 1 if self.request.n is None else self.request.n
self.finish_reason_sent = [False] * self.num_choices
self.role = self._get_role(self.request)
def _get_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
role = "assistant"
else:
role = request.messages[-1]["role"]
return role
def _stream_usage_info(
self, request: ChatCompletionRequest, prompt_tokens: int, completion_tokens: int
):
if (
request.stream_options
and request.stream_options.include_usage
and request.stream_options.continuous_usage_stats
):
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
usage = None
return usage
def _create_logprobs(
self, token_ids: List[int], logprobs: List[float]
) -> ChatCompletionLogProbs:
assert len(token_ids) == len(
logprobs
), "token_ids and logprobs have different lengths"
content: List[ChatCompletionLogProbsContent] = []
for token_id, logprob in zip(token_ids, logprobs):
token = self.tokenizer.decode(token_id)
# returning multiple logprobs is not supported
first_logprob = ChatCompletionLogProbsContent(
token=token,
# NOTE: min logprob -9999.0 for probabilities extremely close to 0
logprob=max(logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")),
)
content.append(first_logprob)
chat_logprobs = ChatCompletionLogProbs(content=content)
return chat_logprobs
def get_chat_stream_response(
self,
request_id: str,
res: RequestOutput,
first_iteration: bool,
) -> DisaggChatCompletionStreamResponse:
def get_first_chat(
num_tokens: int, role: str | None = None, content: str | None = None
):
for i in range(self.num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role, content=content),
finish_reason=None,
)
chunk = DisaggChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice_data],
model=self.model,
)
chunk.usage = self._stream_usage_info(
self.request, num_tokens, completion_tokens=0
)
return chunk
prompt_tokens = len(res.prompt_token_ids)
if first_iteration:
return get_first_chat(prompt_tokens, role=self.role)
for output in res.outputs:
i = output.index
if self.finish_reason_sent[i]:
continue
delta_text = output.text_diff
if (
self.request.tool_choice
and type(self.request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=self.request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text)
choice = ChatCompletionResponseStreamChoice(
index=i, delta=delta_message, finish_reason=None
)
if self.request.logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = self._create_logprobs(token_ids, logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
self.finish_reason_sent[i] = True
chunk = DisaggChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice],
model=self.model,
)
chunk.usage = self._stream_usage_info(
self.request, prompt_tokens, output.length
)
return chunk
def create_final_stream_response(
self,
request_id: str,
final_result: RequestOutput,
) -> DisaggChatCompletionStreamResponse:
prompt_tokens = len(final_result.prompt_token_ids)
completion_tokens = sum(output.length for output in final_result.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = DisaggChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion",
choices=[],
model=self.model,
usage=final_usage,
)
return final_usage_chunk
async def create_chat_response(
self,
request: ChatCompletionRequest,
conversation: List[Dict[str, Any]],
model: str,
promise: RequestOutput,
) -> ChatCompletionResponse:
await promise.aresult()
choices: List[ChatCompletionResponseChoice] = []
role = self._get_role(request)
for output in promise.outputs:
if request.tool_choice and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam
):
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text,
)
)
],
)
else:
message = ChatMessage(role=role, content=output.text)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
if request.logprobs:
choice.logprobs = self._create_logprobs(
output.token_ids, output.logprobs
)
choices.append(choice)
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(promise.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in promise.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
model=model,
choices=choices,
usage=usage,
)
return response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import signal
import uuid
from common.base_engine import BaseTensorrtLLMEngine
from common.processor import merge_promises, parse_chat_message_content
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.logger import logger
logger.set_level("debug")
async def chat_generator(engine: BaseTensorrtLLMEngine, request):
if engine._llm_engine is None:
raise RuntimeError("Engine not initialized")
logger.debug(f"Received chat request: {request}")
request_id = str(uuid.uuid4())
engine._ongoing_request_count += 1
try:
conversation = []
for message in request.messages:
conversation.extend(parse_chat_message_content(message))
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
prompt: str = engine._tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
**(request.chat_template_kwargs or {}),
)
sampling_params = request.to_sampling_params()
promise = engine._llm_engine.generate_async(
prompt,
sampling_params,
streaming=request.stream,
)
# NOTE: somehow stream and non-stream is working with the same path
response_generator = engine.chat_processor.stream_response(
request, request_id, conversation, promise
)
async for response in response_generator:
yield response
engine._ongoing_request_count -= 1
except CppExecutorError:
# If internal executor error is raised, shutdown the server
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
async def completion_generator(engine: BaseTensorrtLLMEngine, request):
if engine._llm_engine is None:
raise RuntimeError("Engine not initialized")
engine._ongoing_request_count += 1
logger.debug(f"Received completion request: {request}")
if isinstance(request.prompt, str) or (
isinstance(request.prompt, list) and isinstance(request.prompt[0], int)
):
prompts = [request.prompt]
else:
prompts = request.prompt
promises = []
sampling_params = request.to_sampling_params()
try:
for prompt in prompts:
promise = engine._llm_engine.generate_async(
prompt,
sampling_params,
streaming=request.stream,
)
promises.append(promise)
generator = merge_promises(promises)
num_choices = len(prompts) if request.n is None else len(prompts) * request.n
# NOTE: always send `stream: true` to the worker, and decide whether to aggregate or not before sending the response back to client.
response_generator = engine.completions_processor.create_completion_generator(
request, generator, num_choices
)
async for response in response_generator:
yield json.loads(response)
engine._ongoing_request_count -= 1
except CppExecutorError:
# If internal executor error is raised, shutdown the server
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import c_char_p, c_int64, c_uint32
from tensorrt_llm.logger import logger
logger.set_level("debug")
class DynamoResult:
OK = 0
ERR = 1
class KVCacheEventPublisher:
def __init__(self, namespace: str, component: str, worker_id: int, lib_path: str):
self.lib = None
try:
self.lib = ctypes.CDLL(lib_path)
self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
self.lib.dynamo_llm_init.restype = c_uint32
result = self.lib.dynamo_llm_init(
namespace.encode(), component.encode(), worker_id
)
if result == DynamoResult.OK:
logger.info(
"KVCacheEventPublisher initialized successfully. Ready to publish KV Cache Events"
)
else:
logger.info("KVCacheEventPublisher initialization failed!")
except Exception as e:
print(f"Failed to load {lib_path}")
raise e
self.lib.dynamo_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
ctypes.POINTER(ctypes.c_uint64), # parent_hash
ctypes.c_uint64, # lora_id
]
self.lib.dynamo_kv_event_publish_stored.restype = (
ctypes.c_uint32
) # dynamo_llm_result_t
self.lib.dynamo_kv_event_publish_removed.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
]
self.lib.dynamo_kv_event_publish_removed.restype = (
ctypes.c_uint32
) # dynamo_llm_result_t
def stored_event(self, event_id, parent_hash, block_hashes, token_ids, lora_id):
if self.lib is None:
logger.error("KVCacheEventPublisher not initialized!")
return
logger.debug(
f"Stored event: {event_id}, parent_hash: {parent_hash}, block_hashes: {block_hashes}, token_ids: {token_ids}"
)
parent_hash = (
(ctypes.c_uint64 * 1)(parent_hash) if parent_hash is not None else None
)
block_hash_arr = (ctypes.c_uint64 * len(block_hashes))(*block_hashes)
block_hash_len = len(block_hashes)
token_ids_arr = (ctypes.c_uint32 * len(token_ids))(*token_ids)
num_block_tokens = (ctypes.c_size_t * 1)(len(token_ids))
# Publish the event
# TODO: Currently, lora_id is not available in the stored events.
result = self.lib.dynamo_kv_event_publish_stored(
event_id, # uint64_t event_id
token_ids_arr, # const uint32_t *token_ids
num_block_tokens, # const uintptr_t *num_block_tokens
block_hash_arr, # const uint64_t *block_ids
block_hash_len, # uintptr_t num_blocks
parent_hash, # const uint64_t *parent_hash
lora_id, # uint64_t lora_id
)
if result == DynamoResult.OK:
logger.debug(f"Store - Published KV Event: {block_hashes}")
else:
logger.error(f"Store - Failed to Publish KV Event: {block_hashes}")
def removed_event(self, event_id, block_hashes):
if self.lib is None:
logger.error("KVCacheEventPublisher not initialized!")
return
result = self.lib.dynamo_kv_event_publish_removed(
event_id,
(ctypes.c_uint64 * len(block_hashes))(*block_hashes),
(ctypes.c_size_t * 1)(len(block_hashes)),
)
if result == DynamoResult.OK:
logger.debug(f"Remove - Published KV Event: {block_hashes}")
else:
logger.error(f"Remove - Failed to Publish KV Event: {block_hashes}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Tuple
import yaml
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig
@dataclass
class LLMAPIConfig:
def __init__(
self,
model_name: str,
model_path: str | None = None,
pytorch_backend_config: PyTorchConfig | None = None,
kv_cache_config: KvCacheConfig | None = None,
**kwargs,
):
self.model_name = model_name
self.model_path = model_path
self.pytorch_backend_config = pytorch_backend_config
self.kv_cache_config = kv_cache_config
self.extra_args = kwargs
def to_dict(self) -> Dict[str, Any]:
data = {
"pytorch_backend_config": self.pytorch_backend_config,
"kv_cache_config": self.kv_cache_config,
}
if self.extra_args:
data.update(self.extra_args)
return data
def update_sub_configs(self, other_config: Dict[str, Any]):
if "pytorch_backend_config" in other_config:
self.pytorch_backend_config = PyTorchConfig(
**other_config["pytorch_backend_config"]
)
self.extra_args.pop("pytorch_backend_config", None)
if "kv_cache_config" in other_config:
self.kv_cache_config = KvCacheConfig(**other_config["kv_cache_config"])
self.extra_args.pop("kv_cache_config", None)
def _get_llm_args(engine_config):
# Only do model validation checks and leave other checks to LLMAPI
if "model_name" not in engine_config:
raise ValueError("Model name is required in the TRT-LLM engine config.")
if engine_config.get("model_path", ""):
if os.path.exists(engine_config.get("model_path", "")):
engine_config["model_path"] = Path(engine_config["model_path"])
else:
raise ValueError(f"Model path {engine_config['model_path']} does not exist")
model_name = engine_config["model_name"]
model_path = engine_config.get("model_path", None)
engine_config.pop("model_name")
engine_config.pop("model_path", None)
# Store all other args as kwargs
llm_api_config = LLMAPIConfig(
model_name=model_name,
model_path=model_path,
**engine_config,
)
# Parse supported sub configs and remove from kwargs
llm_api_config.update_sub_configs(engine_config)
return llm_api_config
def _init_engine_args(engine_args_filepath):
"""Initialize engine arguments from config file."""
if not os.path.isfile(engine_args_filepath):
raise ValueError(
"'YAML file containing TRT-LLM engine args must be provided in when launching the worker."
)
try:
with open(engine_args_filepath) as file:
trtllm_engine_config = yaml.safe_load(file)
except yaml.YAMLError as e:
raise RuntimeError(f"Failed to parse engine config: {e}")
return _get_llm_args(trtllm_engine_config)
def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser")
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
parser.add_argument(
"--llmapi-disaggregated-config",
"-c",
type=str,
help="Path to the llmapi disaggregated config file",
default=None,
)
parser.add_argument(
"--publish-kv-cache-events",
action="store_true",
help="Publish KV cache events from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
)
parser.add_argument(
"--publish-stats",
action="store_true",
help="Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
)
parser.add_argument(
"--kv-block-size",
type=int,
help="KV block size for TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
default=64,
)
args = parser.parse_args()
return (args, _init_engine_args(args.engine_args))
def parse_dynamo_run_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
parser = argparse.ArgumentParser(
description="A TensorRT-LLM Dynamo-run engine parser"
)
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
# Disaggregated mode is not supported in dynamo-run launcher yet.
# parser.add_argument(
# "--llmapi-disaggregated-config",
# "-c",
# type=str,
# help="Path to the llmapi disaggregated config file",
# default=None,
# )
parser.add_argument(
"--publish-kv-cache-events",
action="store_true",
help="Publish KV cache events from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
)
parser.add_argument(
"--publish-stats",
action="store_true",
help="Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
)
args, _ = parser.parse_known_args()
return (args, _init_engine_args(args.engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import time
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
List,
Tuple,
TypedDict,
Union,
)
from common.protocol import (
DisaggCompletionResponseStreamChoice,
DisaggCompletionStreamResponse,
DisaggregatedTypeConverter,
)
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
DeltaMessage,
FunctionCall,
ToolCall,
UsageInfo,
)
from transformers import AutoTokenizer
logger.set_level("debug")
class ConversationMessage(TypedDict):
role: str
content: str
def parse_chat_message_content(
message: ChatCompletionMessageParam,
) -> Union[ConversationMessage, List[ConversationMessage], List[None]]:
role = message["role"]
content = message.get("content")
if content is None:
return []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)]
texts: List[str] = []
for part in content:
part_type = part["type"]
if part_type == "text":
text = part["text"] # type: ignore
texts.append(text)
else:
raise NotImplementedError(f"{part_type} is not supported")
text_prompt = "\n".join(texts)
return [ConversationMessage(role=role, content=text_prompt)]
class ChatProcessor:
def __init__(self, model: str, tokenizer: AutoTokenizer):
self.model = model
self.tokenizer = tokenizer
def _get_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
role = "assistant"
else:
role = request.messages[-1]["role"]
return role
def _stream_usage_info(
self, request: ChatCompletionRequest, prompt_tokens: int, completion_tokens: int
):
if (
request.stream_options
and request.stream_options.include_usage
and request.stream_options.continuous_usage_stats
):
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
usage = None
return usage
def _create_logprobs(
self, token_ids: List[int], logprobs: List[float]
) -> ChatCompletionLogProbs:
assert len(token_ids) == len(
logprobs
), "token_ids and logprobs have different lengths"
content: List[ChatCompletionLogProbsContent] = []
for token_id, logprob in zip(token_ids, logprobs):
token = self.tokenizer.decode(token_id)
# returning multiple logprobs is not supported
first_logprob = ChatCompletionLogProbsContent(
token=token,
logprob=max(logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")),
)
content.append(first_logprob)
chat_logprobs = ChatCompletionLogProbs(content=content)
return chat_logprobs
async def _chat_stream_generator(
self,
request: ChatCompletionRequest,
request_id: str,
conversation: List[Dict[str, Any]],
promise: RequestOutput,
) -> AsyncGenerator[str, None]:
first_iteration = True
num_choices = 1 if request.n is None else request.n
finish_reason_sent = [False] * num_choices
role = self._get_role(request)
def yield_first_chat(
num_tokens: int, role: str | None = None, content: str | None = None
):
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role, content=content),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice_data],
model=self.model,
)
chunk.usage = self._stream_usage_info(request, num_tokens, 0)
data = chunk.model_dump_json(exclude_unset=True)
return data
async for res in promise:
prompt_tokens = len(res.prompt_token_ids)
if first_iteration:
yield f"data: {yield_first_chat(prompt_tokens, role=role)} \n\n"
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
yield f"data: {yield_first_chat(prompt_tokens, content=last_msg_content)}\n\n"
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_text = output.text_diff
if (
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text)
choice = ChatCompletionResponseStreamChoice(
index=i, delta=delta_message, finish_reason=None
)
if request.logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = self._create_logprobs(token_ids, logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice],
model=self.model,
)
chunk.usage = self._stream_usage_info(
request, prompt_tokens, output.length
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
if request.stream_options and request.stream_options.include_usage:
completion_tokens = sum(output.length for output in promise.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion",
choices=[],
model=self.model,
usage=final_usage,
)
final_usage_data = final_usage_chunk.model_dump_json()
yield f"data: {final_usage_data}\n\n"
yield "data: [DONE]\n\n"
async def stream_response(
self,
request: ChatCompletionRequest,
request_id: str,
conversation: List[Dict[str, Any]],
promise: RequestOutput,
) -> AsyncGenerator[str, None]:
assert request.stream, "Only stream is supported"
async for raw_response in self._chat_stream_generator(
request, request_id, conversation, promise
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
async def create_chat_response(
self,
request: ChatCompletionRequest,
conversation: List[Dict[str, Any]],
model: str,
promise: RequestOutput,
) -> ChatCompletionResponse:
await promise.aresult()
choices: List[ChatCompletionResponseChoice] = []
role = self._get_role(request)
for output in promise.outputs:
if request.tool_choice and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam
):
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text,
)
)
],
)
else:
message = ChatMessage(role=role, content=output.text)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
if request.logprobs:
choice.logprobs = self._create_logprobs(
output.token_ids, output.logprobs
)
choices.append(choice)
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(promise.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in promise.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
model=model,
choices=choices,
usage=usage,
)
return response
def merge_promises(
promises: List[RequestOutput],
) -> AsyncIterator[Tuple[int, RequestOutput]]:
outputs = asyncio.Queue() # type: ignore
finished = [False] * len(promises)
async def producer(i: int, promise: RequestOutput):
async for output in promise:
await outputs.put((i, output))
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, promise)) for i, promise in enumerate(promises)
]
async def consumer():
while not all(finished) or not outputs.empty():
item = await outputs.get()
yield item
await asyncio.gather(*_tasks)
return consumer()
class CompletionsProcessor:
def __init__(self, model: str):
self.model = model
def _post_process(self, request, prompt_idx, num_choices, request_output):
res = []
echoed = [False] * num_choices
num_responses_per_request = 1 if request.n is None else request.n
for gen_idx, output in enumerate(request_output.outputs):
response_idx = prompt_idx * num_responses_per_request + gen_idx
delta_text = output.text_diff
if request.echo and not echoed[response_idx]:
delta_text = request.prompt + delta_text
echoed[response_idx] = True
choice = DisaggCompletionResponseStreamChoice(
index=response_idx,
text=delta_text,
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
)
if output.disaggregated_params is not None:
choice.disaggregated_params = (
DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
)
)
chunk = DisaggCompletionStreamResponse(
model=self.model,
choices=[choice],
)
res.append(chunk.model_dump_json())
return res
async def create_completion_generator(
self,
request: CompletionRequest,
generator: AsyncIterator[Tuple[int, RequestOutput]],
num_choices: int,
):
        async for prompt_idx, request_output in generator:
            pp_res = self._post_process(request, prompt_idx, num_choices, request_output)
for _p in pp_res:
yield _p
async def create_completion_response(
self,
request: CompletionRequest,
generator: AsyncIterator[Tuple[int, RequestOutput]],
num_choices: int,
):
choices = [None] * num_choices
        num_responses_per_request = 1 if request.n is None else request.n
num_prompt_tokens = num_gen_tokens = 0
async for prompt_idx, request_output in generator:
num_prompt_tokens += len(request_output.prompt_token_ids)
for gen_idx, output in enumerate(request_output.outputs):
num_gen_tokens += len(output.token_ids)
output_text = output.text
if request.echo:
output_text = request_output.prompt + output_text
                idx = prompt_idx * num_responses_per_request + gen_idx
disaggregated_params = CompletionResponseChoice.to_disaggregated_params(
output.disaggregated_params
)
choice = CompletionResponseChoice(
index=idx,
text=output_text,
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
disaggregated_params=disaggregated_params,
)
choices[idx] = choice
usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_gen_tokens,
total_tokens=num_gen_tokens + num_prompt_tokens,
)
response = CompletionResponse(
model=self.model,
choices=choices,
usage=usage_info,
)
return response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import time
import uuid
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionResponseStreamChoice,
DisaggregatedParams,
UsageInfo,
)
class Tokens(BaseModel):
tokens: list[int]
class Request(BaseModel):
prompt: str
sampling_params: dict
streaming: bool
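# Translates between the OpenAI-protocol DisaggregatedParams (opaque state carried as
# base64 text) and TensorRT-LLM's native DisaggregatedParams (raw bytes).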
class DisaggregatedTypeConverter:
@staticmethod
def to_llm_disaggregated_params(
disaggregated_params: DisaggregatedParams,
) -> LlmDisaggregatedParams:
if disaggregated_params is None:
return None
else:
opaque_state = (
base64.b64decode(disaggregated_params.encoded_opaque_state)
if disaggregated_params.encoded_opaque_state is not None
else None
)
return LlmDisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
)
@staticmethod
def to_oai_disaggregated_params(
tllm_disagg_params: LlmDisaggregatedParams,
) -> DisaggregatedParams:
if tllm_disagg_params is None:
return None
else:
encoded_opaque_state = (
base64.b64encode(tllm_disagg_params.opaque_state).decode("utf-8")
                if tllm_disagg_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=tllm_disagg_params.request_type,
first_gen_tokens=tllm_disagg_params.first_gen_tokens,
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encoded_opaque_state,
)
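# Protocol extensions: the Disagg* request/response models below add an optional
# disaggregated_params field to the standard OpenAI chat/completions schemas.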
# Chat Completions
class DisaggChatCompletionRequest(ChatCompletionRequest):
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
class DisaggChatCompletionStreamResponse(ChatCompletionStreamResponse):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
## Completions
class DisaggCompletionResponseStreamChoice(CompletionResponseStreamChoice):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
class DisaggCompletionStreamResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[DisaggCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
import traceback
import weakref
from queue import Queue
from typing import Callable, Optional, Union
from tensorrt_llm.logger import logger
logger.set_level("info")
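# Background thread that repeatedly schedules an async task onto a given event loop
# via run_coroutine_threadsafe and forwards any exception to an error queue, until
# stop() is called.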
class ManagedThread(threading.Thread):
def __init__(
self,
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]],
error_queue: Optional[Queue] = None,
name: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs,
):
super().__init__(name=name)
self.task = task
self.error_queue = error_queue
self.kwargs = kwargs
self.loop = loop
self.daemon = True
self.stop_event = threading.Event()
def set_loop(self, loop: asyncio.AbstractEventLoop):
self.loop = loop
def run(self):
while not self.stop_event.is_set():
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]] = self.task
if isinstance(task, weakref.WeakMethod):
task = task()
if task is None:
# Normally, this should not happen.
logger.warning("WeakMethod is expired.")
break
if task is None:
break
try:
if self.loop is None:
logger.error("[ManagedThread] Loop not initialized!")
break
future = asyncio.run_coroutine_threadsafe(
task(**self.kwargs), self.loop
)
_ = future.result()
except Exception as e:
logger.error(
f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
)
if self.error_queue is not None:
self.error_queue.put(e)
logger.info(f"Thread {self.name} stopped.")
def stop(self):
self.stop_event.set()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import enum
import json
import traceback
from typing import AsyncIterator
import uvloop
from common.base_engine import ChatProcessorMixin
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import parse_chat_message_content
from common.protocol import (
DisaggChatCompletionRequest,
DisaggChatCompletionStreamResponse,
DisaggCompletionStreamResponse,
Tokens,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams
from dynamo.llm import KvRouter
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger.set_level("debug")
class EndpointType(enum.Enum):
chat = "chat"
completion = "completion"
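# Thin wrapper around KvRouter: given the tokenized prompt, yields a worker id
# (or "" to fall back to random routing).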
class Scheduler:
def __init__(self, kv_router: KvRouter):
self.kv_router = kv_router
@dynamo_endpoint(Tokens, str)
async def generate(self, request) -> AsyncIterator[str]:
lora_id = 0
worker_id = None
try:
worker_id = await self.kv_router.schedule(request.tokens, lora_id)
except Exception:
logger.warning(f"Error during worker selection: {traceback.format_exc()}")
worker_id = ""
logger.debug(f"Scheduling to worker_id: {worker_id}")
yield str(worker_id)
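# Routes disaggregated requests: sends each request to a context (prefill) worker first,
# then forwards the returned disaggregated params to a generation (decode) worker and
# streams its responses back.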
class Router(ChatProcessorMixin):
def __init__(
self,
ctx_chat_client,
gen_chat_client,
ctx_completion_client,
gen_completion_client,
scheduler: Scheduler,
engine_config: LLMAPIConfig,
):
self.ctx_chat_client = ctx_chat_client
self.gen_chat_client = gen_chat_client
self.ctx_completion_client = ctx_completion_client
self.gen_completion_client = gen_completion_client
self.scheduler = scheduler
        # allows use of the tokenizer
super().__init__(engine_config)
logger.info("INITIALIZED ROUTER")
async def _get_ctx_resp(self, request, ctx_client, endpoint_type: EndpointType):
logger.debug(f"Received request {request}")
        # NOTE: this increases TTFT since the prompt is encoded here
        # and then encoded again in the worker.
        # TODO: implement our own request processing and protocols so that only
        # token ids are sent to the llmapi worker.
if endpoint_type == EndpointType.completion:
token_ids = self._tokenizer.encode(request.prompt)
else:
conversation = []
for message in request.messages:
conversation.extend(parse_chat_message_content(message))
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
token_ids = self._tokenizer.apply_chat_template(
conversation=conversation,
tokenize=True,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
**(request.chat_template_kwargs or {}),
)
worker_id_generator: AsyncIterator = self.scheduler.generate(
Tokens(tokens=token_ids).model_dump_json()
)
worker_id = (
await worker_id_generator.__anext__()
) # only one worker id is returned
request.max_completion_tokens = 1
request.disaggregated_params = DisaggregatedParams(request_type="context_only")
logger.debug(f"[router] Sending request to context server: {request}")
if worker_id == "":
ctx_resp = [
resp
async for resp in await ctx_client.random(request.model_dump_json())
]
else:
ctx_resp = [
resp
async for resp in await ctx_client.direct(
request.model_dump_json(), int(worker_id)
)
]
if len(ctx_resp) > 1:
raise ValueError(
"Context server returned more than one response. This is currently not supported in disaggregated server."
)
logger.debug(
f"[router] received response from context server: {ctx_resp[0].data()}"
)
return ctx_resp[0].data()
    # TODO (shreyasm): The only reason we can't combine the two methods below further
    # is that the disaggregated params live in different places: they should sit under
    # the choices field of the response object, which is the case for completions
    # but not for chat.
@dynamo_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
async def generate_completion(self, request):
# These settings are needed to satisfy request checks.
request.skip_special_tokens = False
request.add_special_tokens = False
request.spaces_between_special_tokens = False
gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(
request, self.ctx_completion_client, EndpointType.completion
)
ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate(
ctx_resp_obj.choices[0].disaggregated_params
)
gen_req.disaggregated_params.request_type = "generation_only"
if request.stream:
yield json.loads(
ctx_resp_obj.model_dump_json(
exclude_unset=True, exclude={"disaggregated_params"}
)
)
logger.debug(f"[router] Sending request to generation server: {gen_req}")
async for response in await self.gen_completion_client.round_robin(
gen_req.model_dump_json()
):
gen_resp_obj = DisaggCompletionStreamResponse.model_validate(
response.data()
)
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
@dynamo_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
async def generate_chat(self, request):
# These settings are needed to satisfy request checks.
request.skip_special_tokens = False
request.add_special_tokens = False
request.spaces_between_special_tokens = False
gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(
request, self.ctx_chat_client, EndpointType.chat
)
ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate(
ctx_resp_obj.disaggregated_params
)
gen_req.disaggregated_params.request_type = "generation_only"
if request.stream:
yield json.loads(
ctx_resp_obj.model_dump_json(
exclude_unset=True, exclude={"disaggregated_params"}
)
)
logger.debug(f"[router] Sending request to generation server: {gen_req}")
async for response in await self.gen_chat_client.round_robin(
gen_req.model_dump_json()
):
gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(
response.data()
)
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
@dynamo_worker()
async def worker(runtime: DistributedRuntime, args, engine_config):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("dynamo").component("router")
await component.create_service()
ctx_completion_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-ctx")
.endpoint("completions")
.client()
)
gen_completion_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-gen")
.endpoint("completions")
.client()
)
ctx_chat_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-ctx")
.endpoint("chat/completions")
.client()
)
gen_chat_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-gen")
.endpoint("chat/completions")
.client()
)
# Only listen to context server for now
kv_listener = runtime.namespace("dynamo").component("tensorrt-llm-ctx")
await kv_listener.create_service()
kv_router = KvRouter(runtime, kv_listener, args.kv_block_size)
completions_endpoint = component.endpoint("completions")
chat_endpoint = component.endpoint("chat/completions")
scheduler = Scheduler(kv_router)
router = Router(
ctx_chat_client,
gen_chat_client,
ctx_completion_client,
gen_completion_client,
scheduler,
engine_config,
)
await asyncio.gather(
completions_endpoint.serve_endpoint(router.generate_completion),
chat_endpoint.serve_endpoint(router.generate_chat),
)
if __name__ == "__main__":
uvloop.install()
args, engine_config = parse_tensorrt_llm_args()
asyncio.run(worker(args, engine_config))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
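# Disaggregated serving config: two context (prefill) servers and two generation
# (decode) servers, each with tensor parallel size 2, split across node1 and node2.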
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
context_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node1:8001"
- "node1:8002"
generation_servers:
num_instances: 2
gpu_fraction: 0.25
tp_size: 2
pp_size: 1
urls:
- "node2:8003"
- "node2:8004"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This will overwrite the llm_api_config.yaml
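# Single-node layout: one context (prefill) server and one generation (decode) server,
# each TP=1, with overlap scheduling and CUDA graphs enabled only on the generation side.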
hostname: localhost
port: 8000
context_servers:
num_instances: 1
tensor_parallel_size: 1
moe_expert_parallel_size: 1
kv_cache_config:
free_gpu_memory_fraction: 0.45
pytorch_backend_config:
enable_overlap_scheduler: false
use_cuda_graph: false
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
moe_expert_parallel_size: 1
kv_cache_config:
free_gpu_memory_fraction: 0.95
pytorch_backend_config:
enable_overlap_scheduler: true
use_cuda_graph: true
urls:
- "localhost:8002"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This will overwrite the llm_api_config.yaml
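# Prefill-heavy layout: four context (prefill) servers and one generation (decode)
# server; the context servers also enable KV-cache event buffering, block reuse, and
# iteration perf stats.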
hostname: localhost
port: 8000
context_servers:
num_instances: 4
tensor_parallel_size: 1
moe_expert_parallel_size: 1
kv_cache_config:
free_gpu_memory_fraction: 0.45
event_buffer_max_size: 1024
enable_block_reuse: true
pytorch_backend_config:
enable_overlap_scheduler: false
use_cuda_graph: false
enable_iter_perf_stats: true
urls:
- "localhost:8001"
- "localhost:8002"
- "localhost:8003"
- "localhost:8004"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
moe_expert_parallel_size: 1
kv_cache_config:
free_gpu_memory_fraction: 0.95
pytorch_backend_config:
enable_overlap_scheduler: true
use_cuda_graph: true
urls:
- "localhost:8005"