Commit 6eb10540 authored by Tanmay Verma, committed by GitHub

feat: Add TensorRT-LLM example for dynamo serve/run (#456)


Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent b865bd4f
@@ -17,6 +17,7 @@
**/*.plan
**/*.onnx
**/*.plan
**/*.etcd
**/.cache/*
**/*onnx*
# Engine must be allowed because code contains dynamo_engine.py
@@ -37,4 +38,4 @@
**/*backup*/
.dockerignore
**/target/*
**/*safetensors
\ No newline at end of file
**/*safetensors
@@ -17,6 +17,7 @@
**/[Oo][Uu][Tt]/
**/[Rr]elease/
**/[Tt][Mm][Pp]/
**/*.etcd
.markdownlint.json
CMakeCache.txt
@@ -80,4 +81,6 @@ __pycache__/
### Helm ###
*.tgz
Chart.lock
generated-values.yaml
\ No newline at end of file
generated-values.yaml
TensorRT-LLM
\ No newline at end of file
@@ -51,10 +51,8 @@ DOCKERFILE=${SOURCE_DIR}/Dockerfile
BUILD_CONTEXT=$(dirname "$(readlink -f "$SOURCE_DIR")")
# Base Images
TENSORRTLLM_BASE_VERSION=25.01
# FIXME: Need a public image for public consumption
TENSORRTLLM_BASE_IMAGE="gitlab-master.nvidia.com:5005/dl/dgx/tritonserver/tensorrt-llm/amd64"
TENSORRTLLM_BASE_IMAGE_TAG=krish-fix-trtllm-build.23766174
TENSORRTLLM_BASE_IMAGE=tensorrt_llm/release
TENSORRTLLM_BASE_IMAGE_TAG=latest
TENSORRTLLM_PIP_WHEEL_PATH=""
VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
@@ -354,6 +352,19 @@ if [ -z "$RUN_PREFIX" ]; then
set -x
fi
# Check if the TensorRT-LLM base image exists
if [[ $FRAMEWORK == "TENSORRTLLM" ]]; then
if docker inspect --type=image "$BASE_IMAGE:$BASE_IMAGE_TAG" > /dev/null 2>&1; then
echo "Image '$BASE_IMAGE:$BASE_IMAGE_TAG' is found."
else
echo "Image '$BASE_IMAGE:$BASE_IMAGE_TAG' is not found." >&2
echo "Please build the TensorRT-LLM base image first. Run ./build_trtllm_base_image.sh" >&2
echo "or use --base-image and --base-image-tag to an existing TensorRT-LLM base image." >&2
echo "See https://nvidia.github.io/TensorRT-LLM/installation/build-from-source-linux.html for more information." >&2
exit 1
fi
fi
$RUN_PREFIX docker build -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE
{ set +x; } 2>/dev/null
#!/bin/bash -e
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Build the TRT-LLM base image for Dynamo with TensorRT-LLM.
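# Usage: ./build_trtllm_base_image.sh [-c <trtllm_commit>]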
TRTLLM_COMMIT=9b931c0f6
while getopts "c:" opt; do
case ${opt} in
c) TRTLLM_COMMIT=$OPTARG ;;
*) echo "Invalid option" >&2; exit 1 ;;
esac
done
(cd /tmp && \
# Clone the TensorRT-LLM repository.
if [ ! -d "TensorRT-LLM" ]; then
git clone https://github.com/NVIDIA/TensorRT-LLM.git
fi
cd TensorRT-LLM
# Checkout the specified commit.
git checkout $TRTLLM_COMMIT
# Update the submodules.
git submodule update --init --recursive
git lfs pull
# Build the TRT-LLM base image.
make -C docker release_build)
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# LLM Deployment Examples using TensorRT-LLM
This directory contains examples and reference implementations for deploying Large Language Models (LLMs) in various configurations using TensorRT-LLM.
## Deployment Architectures
See [deployment architectures](../llm/README.md#deployment-architectures) for an overview of the general architecture.
Note that this TensorRT-LLM example does not support all of those options yet.
### Prerequisites
Start the required services (etcd and NATS) using [Docker Compose](../../deploy/docker-compose.yml):
```bash
docker compose -f deploy/docker-compose.yml up -d
```
### Build docker
#### Step 1: Build TensorRT-LLM base container image
Because of a known C++11 ABI compatibility issue within the NGC PyTorch container, we rebuild TensorRT-LLM from source.
See [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) for more information.
Use the helper script to build a TensorRT-LLM base container image. The script uses a pinned commit ID from the TensorRT-LLM main branch.
```bash
./container/build_trtllm_base_image.sh
```
See [here](https://nvidia.github.io/TensorRT-LLM/installation/build-from-source-linux.html#option-1-build-tensorrt-llm-in-one-step) for more details on building from source.
If you already have a TensorRT-LLM container image, you can skip this step.
#### Step 2: Build the Dynamo container
```bash
./container/build.sh --framework tensorrtllm
```
This build script internally points to the base container image built in Step 1. If you skipped that step because you already have a TensorRT-LLM container image available, pass that image to the build script as the base.
```bash
# Build the dynamo image with a different TRT-LLM base image.
./container/build.sh --framework TENSORRTLLM --base-image <trtllm-base-image> --base-image-tag <trtllm-base-image-tag>
```
### Run container
```bash
./container/run.sh --framework tensorrtllm -it
```
## Run Deployment
### Example architectures
#### Aggregated serving
```bash
cd /workspace/examples/tensorrt_llm
dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
```
#### Aggregated serving with KV Routing
```bash
cd /workspace/examples/tensorrt_llm
dynamo serve graphs.agg_router:Frontend -f ./configs/agg_router.yaml
```
#### Aggregated serving using Dynamo Run
```bash
cd /workspace/examples/tensorrt_llm
dynamo run out=pystr:./engines/agg_engine.py -- --engine_args ./configs/llm_api_config.yaml
```
The above command loads the model specified in `llm_api_config.yaml` and starts accepting
text input from the client. For more details on the `dynamo run` command, refer to the
[dynamo run](/launch/README.md#python-bring-your-own-engine) documentation.

Currently, only the aggregated deployment option is supported by `dynamo run` for TensorRT-LLM;
support for disaggregated deployment is under development. This mode does *not* require
the other services mentioned in the [Prerequisites](#prerequisites) section.
<!--
This is work in progress and will be enabled soon.
#### Disaggregated serving
```bash
cd /workspace/examples/llm
dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
```
#### Disaggregated serving with KV Routing
```bash
cd /workspace/examples/llm
dynamo serve graphs.disagg_router:Frontend -f ./configs/disagg_router.yaml
```
-->
### Client
See the [client](../llm/README.md#client) section to learn how to send requests to the deployment.
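For a quick sanity check, here is a minimal request sketch. It assumes the frontend exposes an OpenAI-compatible `/v1/chat/completions` endpoint on port 8080 (the `FrontendConfig` default) and that the model name matches the `served_model_name` in your config; both values below are illustrative.

```python
# Minimal sanity-check client (hypothetical endpoint/model; adjust to your config).
import json
import urllib.request

payload = {
    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",  # must match served_model_name
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8")))
```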
### Close deployment
See the [close deployment](../llm/README.md#close-deployment) section to learn how to close the deployment.
Remaining tasks:
- [ ] Add support for disaggregated serving.
- [ ] Add integration test coverage.
- [ ] Add instructions for benchmarking.
- [ ] Add multi-node support.
- [ ] Merge the code base with llm example to reduce the code duplication.
- [ ] Use processor from dynamo-llm framework.
- [ ] Explore NIXL integration with TensorRT-LLM.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from queue import Queue
from typing import Any, Optional
from common.chat_processor import ChatProcessor, CompletionsProcessor
from common.parser import LLMAPIConfig
from common.utils import ManagedThread
from tensorrt_llm._torch import LLM
from tensorrt_llm.logger import logger
from transformers import AutoTokenizer
from dynamo.llm import KvMetricsPublisher
from .kv_cache_event_publisher import KVCacheEventPublisher
logger.set_level("info")
class ChatProcessorMixin:
def __init__(self, engine_config: LLMAPIConfig):
self._engine_config = engine_config
logger.info(f"Using LLM API config: {self._engine_config.to_dict()}")
# model name for chat processor
self._model_name = self._engine_config.model_name
logger.info(f"Set model name: {self._model_name}")
# model for LLMAPI input
self._model = self._model_name
if self._engine_config.model_path:
self._model = self._engine_config.model_path
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_path
)
logger.info(f"Using model from path: {self._engine_config.model_path}")
else:
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_name
)
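# An explicit "tokenizer" entry in extra_args overrides the tokenizer
# derived above from model_path / model_name.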
if self._engine_config.extra_args.get("tokenizer", None):
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.extra_args.get("tokenizer", None)
)
self.chat_processor = ChatProcessor(self._model_name, self._tokenizer)
self.completions_processor = CompletionsProcessor(
self._model_name, self._tokenizer
)
@dataclass
class TensorrtLLMEngineConfig:
namespace_str: str = "dynamo"
component_str: str = "tensorrt-llm"
engine_config: LLMAPIConfig = None
worker_id: Optional[str] = None
kv_metrics_publisher: Optional[KvMetricsPublisher] = None
publish_stats: bool = False
publish_kv_cache_events: bool = False
# default block size is 32 for pytorch backend
kv_block_size: int = 32
class BaseTensorrtLLMEngine(ChatProcessorMixin):
def __init__(
self,
trt_llm_engine_config: TensorrtLLMEngineConfig,
):
super().__init__(trt_llm_engine_config.engine_config)
self._namespace_str = trt_llm_engine_config.namespace_str
self._component_str = trt_llm_engine_config.component_str
self._worker_id = trt_llm_engine_config.worker_id
self._kv_metrics_publisher = trt_llm_engine_config.kv_metrics_publisher
self._publish_stats = trt_llm_engine_config.publish_stats
self._publish_kv_cache_events = trt_llm_engine_config.publish_kv_cache_events
self._kv_block_size = trt_llm_engine_config.kv_block_size
self._error_queue: Optional[Queue] = None
self._init_engine()
def _init_engine(self):
logger.info("Initializing engine")
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine: Optional[Any] = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
)
self.publish_kv_cache_events_thread = None
self.publish_stats_thread = None
self._event_thread.start()
with self._llm_engine_start_cv:
while self._llm_engine is None:
self._llm_engine_start_cv.wait()
# The 'threading.Thread()' will not raise the exception here should the engine
# fail to start, so the exception is passed back via the engine variable.
if isinstance(self._llm_engine, Exception):
e = self._llm_engine
logger.error(f"Failed to start engine: {e}")
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
raise e
self._error_queue = Queue()
try:
if self._publish_stats:
self._init_publish_metrics_thread()
if self._publish_kv_cache_events:
self._init_publish_kv_cache_events_thread()
except Exception as e:
logger.error(f"Failed to initialize publish metrics threads: {e}")
raise e
def _init_publish_metrics_thread(self):
# Need to publish stats once so that worker can be selected.
# Publishing some dummy values...
request_active_slots = 0
request_total_slots = 4
kv_active_block = 0
kv_total_blocks = 4
num_requests_waiting = 0
gpu_cache_usage_perc = 0.0
gpu_prefix_cache_hit_rate = 0.0
if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!")
return
self._kv_metrics_publisher.publish(
request_active_slots,
request_total_slots,
kv_active_block,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
)
# Prepare threads for publishing stats but don't start them yet.
# TRTLLM needs to start generating tokens first before stats
# can be retrieved.
self.publish_stats_thread = ManagedThread(
self.publish_stats_task,
error_queue=self._error_queue,
name="publish_stats_thread",
)
def _init_publish_kv_cache_events_thread(self):
if self._worker_id is None:
logger.error("Worker ID not initialized!")
return
# TODO: Use python bindings to publish kv cache events once they
# are available.
lib_path = "/opt/dynamo/bindings/lib/libdynamo_llm_capi.so"
self._kv_cache_events_publisher = KVCacheEventPublisher(
self._namespace_str,
self._component_str,
int(self._worker_id),
lib_path,
self._kv_block_size,
)
# Prepare threads for publishing kv cache events but don't start them yet.
# TRTLLM needs to start generating tokens first before kv cache events
# can be retrieved.
self.publish_kv_cache_events_thread = ManagedThread(
self.publish_kv_cache_events_task,
error_queue=self._error_queue,
name="publish_kv_cache_events_thread",
)
async def publish_stats_task(self):
"""
Publish stats to the metrics publisher.
"""
if self._llm_engine is None:
logger.error("LLM engine not initialized!")
return False
if self._kv_metrics_publisher is None:
logger.error("KV metrics publisher not initialized!")
return False
stats = self._llm_engine.get_stats_async(timeout=5)
async for stat in stats:
request_active_slots = stat["numActiveRequests"]
request_total_slots = stat["maxNumActiveRequests"]
kv_active_block = stat["kvCacheStats"]["usedNumBlocks"]
kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
reused_blocks = stat["kvCacheStats"]["reusedBlocks"]
freeNumBlocks = stat["kvCacheStats"]["freeNumBlocks"]
allocTotalBlocks = stat["kvCacheStats"]["allocTotalBlocks"]
allocNewBlocks = stat["kvCacheStats"]["allocNewBlocks"]
# NOTE: num paused requests is always 0 when using the guaranteed-no-evict scheduler (default).
num_requests_waiting = (
stat["numQueuedRequests"]
+ stat["inflightBatchingStats"]["numPausedRequests"]
)
gpu_cache_usage_perc = allocTotalBlocks / kv_total_blocks
gpu_prefix_cache_hit_rate = stat["kvCacheStats"]["cacheHitRate"]
logger.debug(
f"Publishing stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}, num_requests_waiting: {num_requests_waiting}, reused_blocks: {reused_blocks}, freeNumBlocks: {freeNumBlocks}, allocTotalBlocks: {allocTotalBlocks}, allocNewBlocks: {allocNewBlocks}, gpu_cache_usage_perc: {gpu_cache_usage_perc}, gpu_prefix_cache_hit_rate: {gpu_prefix_cache_hit_rate}"
)
self._kv_metrics_publisher.publish(
request_active_slots,
request_total_slots,
kv_active_block,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
)
return True
async def publish_kv_cache_events_task(self):
"""
Publish kv cache events to the events publisher.
"""
if self._llm_engine is None:
logger.error("LLM engine not initialized!")
return False
events = self._llm_engine.get_kv_cache_events_async(timeout=5)
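# Illustrative "stored"/"removed" event payloads, inferred from the parsing
# below (field names grounded in this loop, values hypothetical):
#   {"data": {"type": "stored", "parent_hash": 123, "blocks": [
#       {"block_hash": 456, "tokens": [{"token_id": 1}, ...]}]}}
#   {"data": {"type": "removed", "block_hashes": [456]}}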
async for event_list in events:
for event in event_list:
data = event["data"]
if data["type"] == "stored":
parent_hash = data["parent_hash"]
for block in data["blocks"]:
tokens = []
for token in block["tokens"]:
tokens.append(int(token["token_id"]))
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
self._kv_cache_events_publisher.stored_event(
parent_hash,
block["block_hash"],
tokens,
lora_id,
)
elif data["type"] == "removed":
for block_hash in data["block_hashes"]:
self._kv_cache_events_publisher.removed_event(block_hash)
return True
def _start_threads(self):
if (
self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive()
):
# [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
self._stats_loop = asyncio.get_running_loop()
self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
self.publish_kv_cache_events_thread.start()
logger.debug("Started kv cache events thread")
if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
self._stats_loop = asyncio.get_running_loop()
self.publish_stats_thread.set_loop(self._stats_loop)
self.publish_stats_thread.start()
logger.debug("Started stats thread")
async def _run_llm_engine(self):
# Counter to keep track of ongoing request counts.
self._ongoing_request_count = 0
@asynccontextmanager
async def async_llm_wrapper():
# Create LLM in a thread to avoid blocking
loop = asyncio.get_running_loop()
try:
llm = await loop.run_in_executor(
None,
lambda: LLM(model=self._model, **self._engine_config.to_dict()),
)
yield llm
finally:
if "llm" in locals():
# Run shutdown in a thread to avoid blocking
await loop.run_in_executor(None, llm.shutdown)
try:
async with async_llm_wrapper() as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
# Signal the engine is started and make it visible to other threads.
with self._llm_engine_start_cv:
self._llm_engine = engine
self._llm_engine_start_cv.notify_all()
logger.info("Engine loaded and ready to serve...")
# Wait for the engine shutdown signal.
await self._llm_engine_shutdown_event.wait()
# Stop the publishing threads
if self.publish_stats_thread and self.publish_stats_thread.is_alive():
self.publish_stats_thread.stop()
self.publish_stats_thread.join()
if (
self.publish_kv_cache_events_thread
and self.publish_kv_cache_events_thread.is_alive()
):
self.publish_kv_cache_events_thread.stop()
self.publish_kv_cache_events_thread.join()
# Wait for the ongoing requests to complete.
while self._ongoing_request_count > 0:
logger.info(
"Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(1)
# Cancel all tasks in the event loop.
for task in asyncio.all_tasks(loop=self._event_loop):
if task is not asyncio.current_task():
task.cancel()
except Exception as e:
# Signal and pass the exception back via the engine variable if the engine
# failed to start. If the engine has started, re-raise the exception.
with self._llm_engine_start_cv:
if self._llm_engine is None:
self._llm_engine = e
self._llm_engine_start_cv.notify_all()
return
raise e
self._llm_engine = None
logger.info("Shutdown complete")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict
from typing import Any, Dict, List, Union
from common.protocol import (
DisaggregatedTypeConverter,
DynamoTRTLLMChatCompletionResponseStreamChoice,
DynamoTRTLLMChatCompletionStreamResponse,
DynamoTRTLLMCompletionResponseStreamChoice,
DynamoTRTLLMCompletionStreamResponse,
Tokens,
TRTLLMWorkerRequest,
TRTLLMWorkerResponse,
TRTLLMWorkerResponseOutput,
)
from common.utils import ConversationMessage, ServerType
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
UsageInfo,
)
from transformers import AutoTokenizer
logger.set_level("debug")
def parse_chat_message_content(
message: ChatCompletionMessageParam,
) -> Union[ConversationMessage, List[ConversationMessage], List[None]]:
role = message["role"]
content = message.get("content")
if content is None:
return []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)]
texts: List[str] = []
for part in content:
part_type = part["type"]
if part_type == "text":
text = part["text"] # type: ignore
texts.append(text)
else:
raise NotImplementedError(f"{part_type} is not supported")
text_prompt = "\n".join(texts)
return [ConversationMessage(role=role, content=text_prompt)]
class BaseChatProcessor:
def __init__(self, model: str, tokenizer: AutoTokenizer):
self.model = model
self.tokenizer = tokenizer
def _get_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
role = "assistant"
else:
role = request.messages[-1]["role"]
return role
def _stream_usage_info(
self, request: ChatCompletionRequest, prompt_tokens: int, completion_tokens: int
):
if (
request.stream_options
and request.stream_options.include_usage
and request.stream_options.continuous_usage_stats
):
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
usage = None
return usage
def _create_logprobs(
self, token_ids: List[int], logprobs: List[float]
) -> ChatCompletionLogProbs:
assert len(token_ids) == len(
logprobs
), "token_ids and logprobs have different lengths"
content: List[ChatCompletionLogProbsContent] = []
for token_id, logprob in zip(token_ids, logprobs):
token = self.tokenizer.decode(token_id)
# returning multiple logprobs is not supported
first_logprob = ChatCompletionLogProbsContent(
token=token,
# NOTE: min logprob -9999.0 for probabilities extremely close to 0
logprob=max(logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")),
)
content.append(first_logprob)
chat_logprobs = ChatCompletionLogProbs(content=content)
return chat_logprobs
class ChatProcessor(BaseChatProcessor):
def __init__(
self, model: str, tokenizer: AutoTokenizer, using_engine_generator: bool = False
):
super().__init__(model, tokenizer)
self.using_engine_generator = using_engine_generator
def yield_first_chat(
self,
request: ChatCompletionRequest,
request_id: str,
response: RequestOutput,
content: str | None = None,
):
role = self._get_role(request)
num_choices = 1 if request.n is None else request.n
num_tokens = len(response.prompt_token_ids)
content = response.outputs[0].text_diff
for i in range(num_choices):
choice = DynamoTRTLLMChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role, content=content),
finish_reason=None,
)
if response.outputs[0].disaggregated_params is not None:
choice.disaggregated_params = (
DisaggregatedTypeConverter.to_oai_disaggregated_params(
response.outputs[0].disaggregated_params
)
)
chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[choice],
model=self.model,
)
chunk.usage = self._stream_usage_info(request, num_tokens, 0)
return chunk.model_dump_json()
def create_chat_stream_response(
self,
request: ChatCompletionRequest,
request_id: str,
response: RequestOutput,
conversation: List[Dict[str, Any]],
first_iteration: bool = True,
) -> str:
num_choices = 1 if request.n is None else request.n
finish_reason_sent = [False] * num_choices
role = self._get_role(request)
prompt_tokens = len(response.prompt_token_ids)
if first_iteration:
return self.yield_first_chat(request, request_id, response)
# TODO: Fix this
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
return self.yield_first_chat(
request, request_id, response, content=last_msg_content
)
first_iteration = False
for output in response.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_text = output.text_diff
if (
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text, role=role)
choice = DynamoTRTLLMChatCompletionResponseStreamChoice(
index=i, delta=delta_message, finish_reason=None
)
if request.logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = self._create_logprobs(token_ids, logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
if output.disaggregated_params is not None:
choice.disaggregated_params = (
DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
)
)
chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[choice],
model=self.model,
)
logger.debug(f"[processor] Chunk: {chunk}")
chunk.usage = self._stream_usage_info(request, prompt_tokens, output.length)
return chunk.model_dump_json()
# TODO: make request.stream_options.include_usage = True when stream=False in rust
if request.stream_options and request.stream_options.include_usage:
completion_tokens = sum(output.length for output in response.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
choices=[],
model=self.model,
usage=final_usage,
)
return final_usage_chunk.model_dump_json()
return "data: [DONE]\n\n"
async def preprocess(self, request):
conversation: List[Any] = []
for message in request.messages:
conversation.extend(parse_chat_message_content(message))
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
prompt: str = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
**(request.chat_template_kwargs or {}),
)
sampling_params = request.to_sampling_params()
return TRTLLMWorkerRequest(
id=request.id,
prompt=prompt,
sampling_params=asdict(sampling_params),
conversation=conversation,
disaggregated_params=request.disaggregated_params,
# NOTE: don't include the first token (e.g. <s>) when searching for a prefix match. We might want to exclude all special tokens at some point.
tokens=Tokens(tokens=self.tokenizer.encode(prompt)[1:]),
)
async def postprocess(
self,
engine_generator,
request,
conversation,
server_type: ServerType,
):
async for raw_response in engine_generator:
if self.using_engine_generator:
response = TRTLLMWorkerResponse(
request_id=request.id,
prompt=raw_response.prompt,
prompt_token_ids=raw_response.prompt_token_ids,
outputs=[asdict(raw_response.outputs[0])],
finished=raw_response.finished,
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
else:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
if (
request.disaggregated_params is not None
and server_type == ServerType.CTX
):
response_data = self.yield_first_chat(request, request.id, response)
else:
response_data = self.create_chat_stream_response(
request,
request.id,
response,
conversation,
first_iteration=(request.disaggregated_params is None),
)
logger.debug(f"[postprocessor] Response: {response_data}")
yield response_data
class CompletionsProcessor:
def __init__(self, model: str, tokenizer: AutoTokenizer):
self.model = model
self.tokenizer = tokenizer
def create_completion_stream_response(self, request, response):
num_choices = 1 if request.n is None else request.n
echoed = [False] * num_choices
# len(response.outputs) is always 1
for gen_idx, output in enumerate(response.outputs):
delta_text = output.text_diff
if request.echo and not echoed[gen_idx]:
delta_text = request.prompt + delta_text
echoed[gen_idx] = True
choice = DynamoTRTLLMCompletionResponseStreamChoice(
index=gen_idx,
text=delta_text,
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
)
if output.disaggregated_params is not None:
choice.disaggregated_params = (
DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
)
)
chunk = DynamoTRTLLMCompletionStreamResponse(
model=self.model,
choices=[choice],
)
return chunk.model_dump_json()
async def preprocess(self, request):
if isinstance(request.prompt, str) or (
isinstance(request.prompt, list)
and all(isinstance(x, int) for x in request.prompt)
):
prompt = request.prompt
else:
raise ValueError(
"Invalid prompt type. Only string or list of integers are supported."
)
sampling_params = request.to_sampling_params()
return TRTLLMWorkerRequest(
id=request.id,
prompt=prompt,
sampling_params=asdict(sampling_params),
disaggregated_params=request.disaggregated_params,
tokens=Tokens(tokens=self.tokenizer.encode(prompt)[1:]),
)
async def postprocess(
self,
engine_generator,
request,
):
async for raw_response in engine_generator:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_completion_stream_response(
request,
response,
)
logger.debug(f"[postprocessor] Response: {response_data}")
yield response_data
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import c_char_p, c_int64, c_uint32
from tensorrt_llm.logger import logger
logger.set_level("info")
class DynamoResult:
OK = 0
ERR = 1
class KVCacheEventPublisher:
def __init__(
self,
namespace: str,
component: str,
worker_id: int,
lib_path: str,
kv_block_size: int,
):
self.lib = None
try:
self.lib = ctypes.CDLL(lib_path)
# Argument types match the call below: namespace, component, worker_id, kv_block_size.
self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64, c_uint32]
self.lib.dynamo_llm_init.restype = c_uint32
result = self.lib.dynamo_llm_init(
namespace.encode(), component.encode(), worker_id, kv_block_size
)
if result == DynamoResult.OK:
logger.info(
"KVCacheEventPublisher initialized successfully. Ready to publish KV Cache Events"
)
else:
logger.info("KVCacheEventPublisher initialization failed!")
except Exception as e:
print(f"Failed to load {lib_path}")
raise e
self.lib.dynamo_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
ctypes.POINTER(ctypes.c_uint64), # parent_hash
ctypes.c_uint64, # lora_id
]
self.lib.dynamo_kv_event_publish_stored.restype = (
ctypes.c_uint32
) # dynamo_llm_result_t
self.lib.dynamo_kv_event_publish_removed.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
]
self.lib.dynamo_kv_event_publish_removed.restype = (
ctypes.c_uint32
) # dynamo_llm_result_t
self._event_counter = 0
def stored_event(self, parent_hash, block_hash, token_ids, lora_id):
if self.lib is None:
logger.error("KVCacheEventPublisher not initialized!")
return
logger.debug(
f"Stored parent_hash: {parent_hash}, block_hash: {block_hash}, token_ids: {token_ids}"
)
parent_hash = (
(ctypes.c_uint64 * 1)(parent_hash) if parent_hash is not None else None
)
token_ids_arr = (ctypes.c_uint32 * len(token_ids))(*token_ids)
num_block_tokens = (ctypes.c_size_t * 1)(len(token_ids))
block_hash = (ctypes.c_uint64 * 1)(block_hash)
result = self.lib.dynamo_kv_event_publish_stored(
self._event_counter, # uint64_t event_id
token_ids_arr, # const uint32_t *token_ids
num_block_tokens, # const uintptr_t *num_block_tokens
block_hash, # const uint64_t *block_ids
1, # uintptr_t num_blocks
parent_hash, # const uint64_t *parent_hash
lora_id, # uint64_t lora_id
)
self._event_counter += 1
if result == DynamoResult.OK:
logger.debug(f"Store - Published KV Event: {block_hash}")
else:
logger.error(f"Store - Failed to Publish KV Event: {block_hash}")
def removed_event(self, block_hash):
if self.lib is None:
logger.error("KVCacheEventPublisher not initialized!")
return
result = self.lib.dynamo_kv_event_publish_removed(
self._event_counter,
(ctypes.c_uint64 * 1)(block_hash),
1,
)
self._event_counter += 1
if result == DynamoResult.OK:
logger.debug(f"Remove - Published KV Event: {block_hash}")
else:
logger.error(f"Remove - Failed to Publish KV Event: {block_hash}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Tuple
import yaml
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig
@dataclass
class LLMAPIConfig:
def __init__(
self,
model_name: str,
model_path: str | None = None,
pytorch_backend_config: PyTorchConfig | None = None,
kv_cache_config: KvCacheConfig | None = None,
**kwargs,
):
self.model_name = model_name
self.model_path = model_path
self.pytorch_backend_config = pytorch_backend_config
self.kv_cache_config = kv_cache_config
self.extra_args = kwargs
def to_dict(self) -> Dict[str, Any]:
data = {
"pytorch_backend_config": self.pytorch_backend_config,
"kv_cache_config": self.kv_cache_config,
}
if self.extra_args:
data.update(self.extra_args)
return data
def update_sub_configs(self, other_config: Dict[str, Any]):
if "pytorch_backend_config" in other_config:
self.pytorch_backend_config = PyTorchConfig(
**other_config["pytorch_backend_config"]
)
self.extra_args.pop("pytorch_backend_config", None)
if "kv_cache_config" in other_config:
self.kv_cache_config = KvCacheConfig(**other_config["kv_cache_config"])
self.extra_args.pop("kv_cache_config", None)
def _get_llm_args(engine_config):
# Only do model validation checks and leave other checks to LLMAPI
if "model_name" not in engine_config:
raise ValueError("Model name is required in the TRT-LLM engine config.")
if engine_config.get("model_path", ""):
if os.path.exists(engine_config.get("model_path", "")):
engine_config["model_path"] = Path(engine_config["model_path"])
else:
raise ValueError(f"Model path {engine_config['model_path']} does not exist")
model_name = engine_config["model_name"]
model_path = engine_config.get("model_path", None)
engine_config.pop("model_name")
engine_config.pop("model_path", None)
# Store all other args as kwargs
llm_api_config = LLMAPIConfig(
model_name=model_name,
model_path=model_path,
**engine_config,
)
# Parse supported sub configs and remove from kwargs
llm_api_config.update_sub_configs(engine_config)
return llm_api_config
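# Illustrative engine config dict accepted by _get_llm_args (values hypothetical):
#   {
#       "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",  # required
#       "model_path": "/path/to/local/checkpoint",                 # optional
#       "pytorch_backend_config": {...},  # parsed into PyTorchConfig
#       "kv_cache_config": {...},         # parsed into KvCacheConfig
#       "tokenizer": "...",               # any other key is kept in extra_args
#   }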
def _init_engine_args(engine_args_filepath):
"""Initialize engine arguments from config file."""
if not os.path.isfile(engine_args_filepath):
raise ValueError(
"A YAML file containing the TRT-LLM engine args must be provided when launching the worker."
)
try:
with open(engine_args_filepath) as file:
trtllm_engine_config = yaml.safe_load(file)
except yaml.YAMLError as e:
raise RuntimeError(f"Failed to parse engine config: {e}")
return _get_llm_args(trtllm_engine_config)
def parse_tensorrt_llm_args(
config_args,
) -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser")
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
parser.add_argument(
"--llmapi-disaggregated-config",
"-c",
type=str,
help="Path to the llmapi disaggregated config file",
default=None,
)
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers for aggregated (monolith) server",
)
parser.add_argument(
"--block-size",
type=int,
default=32,
help="Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend.",
)
parser.add_argument(
"--remote-prefill",
action="store_true",
help="Use remote prefill workers for generation server in Disaggregated mode.",
)
args = parser.parse_args(config_args)
return (args, _init_engine_args(args.engine_args))
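# Example invocation using the flags defined above (paths hypothetical):
#   args, engine_config = parse_tensorrt_llm_args(
#       ["--engine_args", "./configs/llm_api_config.yaml", "--router", "kv"]
#   )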
def parse_dynamo_run_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]]]:
parser = argparse.ArgumentParser(
description="A TensorRT-LLM Dynamo-run engine parser"
)
parser.add_argument(
"--engine_args", type=str, required=True, help="Path to the engine args file"
)
# Disaggregated mode is not supported in dynamo-run launcher yet.
# parser.add_argument(
# "--llmapi-disaggregated-config",
# "-c",
# type=str,
# help="Path to the llmapi disaggregated config file",
# default=None,
# )
parser.add_argument(
"--publish-kv-cache-events",
action="store_true",
help="Publish KV cache events from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.",
)
args, _ = parser.parse_known_args()
return (args, _init_engine_args(args.engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, Union
import torch
from common.utils import ConversationMessage
from pydantic import BaseModel, ConfigDict, Field
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponseStreamChoice,
CompletionRequest,
CompletionResponseStreamChoice,
DisaggregatedParams,
UsageInfo,
)
# max_tokens is being deprecated in favor of max_completion_tokens.
# However, the TRT-LLM protocol might still refer to it as max_tokens.
class DynamoTRTLLMCompletionRequest(CompletionRequest):
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
max_completion_tokens: Optional[int] = None
class DynamoTRTLLMChatCompletionRequest(ChatCompletionRequest):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid4().hex)}")
max_completion_tokens: Optional[int] = None
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
class Tokens(BaseModel):
tokens: list[int]
class Request(BaseModel):
prompt: str
sampling_params: dict
streaming: bool
class TRTLLMWorkerRequest(BaseModel):
id: str
prompt: str | None = None
sampling_params: dict
streaming: bool = True
conversation: Optional[List[ConversationMessage]] = Field(default=None)
tokens: Optional[Tokens] = Field(default=None)
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
def to_sampling_params(self) -> SamplingParams:
sampling_params = SamplingParams(
frequency_penalty=self.sampling_params.get("frequency_penalty", 0.0),
return_log_probs=self.sampling_params.get("logprobs", False),
max_tokens=self.sampling_params.get("max_tokens", 16),
n=self.sampling_params.get("n", 1),
presence_penalty=self.sampling_params.get("presence_penalty", 0.0),
seed=self.sampling_params.get("seed", None),
stop=self.sampling_params.get("stop", None),
temperature=self.sampling_params.get("temperature", 0.7),
# chat-completion-sampling-params
best_of=self.sampling_params.get("best_of", None),
use_beam_search=self.sampling_params.get("use_beam_search", False),
top_k=self.sampling_params.get("top_k", 0),
top_p=self.sampling_params.get("top_p", 1.0),
top_p_min=self.sampling_params.get("top_p_min", None),
min_p=self.sampling_params.get("min_p", 0.0),
repetition_penalty=self.sampling_params.get("repetition_penalty", 1.0),
length_penalty=self.sampling_params.get("length_penalty", 1.0),
early_stopping=self.sampling_params.get("early_stopping", False),
stop_token_ids=self.sampling_params.get("stop_token_ids", []),
include_stop_str_in_output=self.sampling_params.get(
"include_stop_str_in_output", False
),
ignore_eos=self.sampling_params.get("ignore_eos", False),
min_tokens=self.sampling_params.get("min_tokens", 0),
skip_special_tokens=self.sampling_params.get("skip_special_tokens", False),
spaces_between_special_tokens=self.sampling_params.get(
"spaces_between_special_tokens", False
),
truncate_prompt_tokens=self.sampling_params.get(
"truncate_prompt_tokens", None
),
# chat-completion-extra-params
add_special_tokens=self.sampling_params.get("add_special_tokens", False),
)
return sampling_params
@dataclass
class TRTLLMWorkerResponseOutput:
index: int
text: str
token_ids: list[int]
logprobs: Optional[List[float]] = None
cumulative_logprob: Optional[float] = None
finish_reason: Optional[Literal["stop", "length", "timeout", "cancelled"]] = None
stop_reason: Optional[Union[int, str]] = None
generation_logits: Optional[torch.Tensor] = None
disaggregated_params: Optional[DisaggregatedParams] = None
_last_text_len: int = field(default=0)
_last_token_ids_len: int = field(default=0)
_last_logprobs_len: int = field(default=0)
_incremental_states: Optional[dict] = field(default=None)
_postprocess_result: Optional[Any] = field(default=None)
text_diff: str = field(default="")
length: int = field(default=0)
def __post_init__(self):
self.text_diff = self.text[self._last_text_len :]
self.length = len(self.token_ids)
class TRTLLMWorkerResponse(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
request_id: str
prompt: str | None = None
prompt_token_ids: list[int]
outputs: list[dict]
finished: bool
# TODO
# prompt_logprobs: list[float]
class DisaggregatedTypeConverter:
@staticmethod
def to_llm_disaggregated_params(
disaggregated_params: DisaggregatedParams,
) -> LlmDisaggregatedParams:
if disaggregated_params is None:
return None
else:
opaque_state = (
base64.b64decode(disaggregated_params.encoded_opaque_state)
if disaggregated_params.encoded_opaque_state is not None
else None
)
return LlmDisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
)
@staticmethod
def to_oai_disaggregated_params(
tllm_disagg_params: LlmDisaggregatedParams,
) -> DisaggregatedParams:
if tllm_disagg_params is None:
return None
else:
encoded_opaque_state = (
base64.b64encode(tllm_disagg_params.opaque_state).decode("utf-8")
if tllm_disagg_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=tllm_disagg_params.request_type,
first_gen_tokens=tllm_disagg_params.first_gen_tokens,
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encoded_opaque_state,
)
# Chat Completions
class DynamoTRTLLMChatCompletionResponseStreamChoice(
ChatCompletionResponseStreamChoice
):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
class DynamoTRTLLMChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid4().hex)}")
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[DynamoTRTLLMChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
## Completions
class DynamoTRTLLMCompletionResponseStreamChoice(CompletionResponseStreamChoice):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
class DynamoTRTLLMCompletionStreamResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[DynamoTRTLLMCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
import traceback
import weakref
from enum import Enum
from queue import Queue
from typing import Callable, Optional, TypedDict, Union
from tensorrt_llm.logger import logger
logger.set_level("info")
class RoutingStrategy(Enum):
ROUND_ROBIN = "round_robin"
RANDOM = "random"
PREFIX = "prefix"
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class ServerType(Enum):
# Generation server used for disaggregated and aggregated requests
GEN = "gen"
# Context server used for disaggregated requests
CTX = "ctx"
class ConversationMessage(TypedDict):
role: str
content: str
class ManagedThread(threading.Thread):
def __init__(
self,
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]],
error_queue: Optional[Queue] = None,
name: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs,
):
super().__init__(name=name)
self.task = task
self.error_queue = error_queue
self.kwargs = kwargs
self.loop = loop
self.daemon = True
self.stop_event = threading.Event()
def set_loop(self, loop: asyncio.AbstractEventLoop):
self.loop = loop
def run(self):
while not self.stop_event.is_set():
task: Optional[Union[Callable[..., bool], weakref.WeakMethod]] = self.task
if isinstance(task, weakref.WeakMethod):
task = task()
if task is None:
# Normally, this should not happen.
logger.warning("WeakMethod is expired.")
break
if task is None:
break
try:
if self.loop is None:
logger.error("[ManagedThread] Loop not initialized!")
break
future = asyncio.run_coroutine_threadsafe(
task(**self.kwargs), self.loop
)
_ = future.result()
except Exception as e:
logger.error(
f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
)
if self.error_queue is not None:
self.error_queue.put(e)
logger.info(f"Thread {self.name} stopped.")
def stop(self):
self.stop_event.set()
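# Typical lifecycle, as used by BaseTensorrtLLMEngine:
#   thread = ManagedThread(task, error_queue=queue, name="publish_stats_thread")
#   thread.set_loop(asyncio.get_running_loop())
#   thread.start()
#   ...
#   thread.stop()
#   thread.join()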
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signal
from dataclasses import asdict
from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
from common.parser import parse_tensorrt_llm_args
from common.protocol import TRTLLMWorkerRequest, TRTLLMWorkerResponse
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.logger import logger
from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger.set_level("debug")
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class TensorRTLLMWorker(BaseTensorrtLLMEngine):
"""
Request handler for the generate endpoint
"""
def __init__(self):
print("Initializing TensorRT-LLM Worker")
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
self.args, self.engine_config = parse_tensorrt_llm_args(config_args)
if self.args.router == "kv":
publish_stats = True
publish_events = True
else:
publish_stats = False
publish_events = False
trt_llm_engine_config = TensorrtLLMEngineConfig(
namespace_str="dynamo",
component_str=class_name,
engine_config=self.engine_config,
publish_stats=publish_stats,
publish_kv_cache_events=publish_events,
kv_block_size=self.args.block_size,
)
if publish_stats:
trt_llm_engine_config.kv_metrics_publisher = KvMetricsPublisher()
trt_llm_engine_config.worker_id = dynamo_context["endpoints"][0].lease_id()
self.trtllm_engine_args = trt_llm_engine_config
@async_on_start
async def async_init(self):
super().__init__(self.trtllm_engine_args)
print("TensorRT-LLM Worker initialized")
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self.metrics_publisher.create_endpoint(component)
@dynamo_endpoint()
async def generate(self, request: TRTLLMWorkerRequest):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
if self._error_queue.qsize() > 0:
error = self._error_queue.get()
raise error
self._ongoing_request_count += 1
try:
# TODO: combine with disagg worker
# TODO: only send tokens. Should be pretty simple.
async for response in self._llm_engine.generate_async(
inputs=request.prompt,
sampling_params=request.to_sampling_params(),
disaggregated_params=None,
streaming=True,
):
yield TRTLLMWorkerResponse(
request_id=request.id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
outputs=[asdict(response.outputs[0])],
finished=response.finished,
).model_dump_json(exclude_unset=True)
except CppExecutorError:
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
self._start_threads()
self._ongoing_request_count -= 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from pathlib import Path
from components.agg_worker import TensorRTLLMWorker
from components.processor import Processor
from pydantic import BaseModel
from dynamo import sdk
from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
def get_http_binary_path():
sdk_path = Path(sdk.__file__)
binary_path = sdk_path.parent / "cli/bin/http"
if not binary_path.exists():
return "http"
else:
return str(binary_path)
class FrontendConfig(BaseModel):
served_model_name: str
endpoint: str
port: int = 8080
@service(
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
)
# todo this should be called ApiServer
class Frontend:
worker = depends(TensorRTLLMWorker)
processor = depends(Processor)
def __init__(self):
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
subprocess.run(
[
"llmctl",
"http",
"remove",
"chat-models",
frontend_config.served_model_name,
]
)
subprocess.run(
[
"llmctl",
"http",
"add",
"chat-models",
frontend_config.served_model_name,
frontend_config.endpoint,
]
)
print("Starting HTTP server")
http_binary = get_http_binary_path()
process = subprocess.Popen(
[http_binary, "-p", str(frontend_config.port)], stdout=None, stderr=None
)
try:
process.wait()
except KeyboardInterrupt:
process.terminate()
process.wait()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import random
import traceback
from argparse import Namespace
from typing import AsyncIterator
from common.protocol import Tokens
from components.agg_worker import TensorRTLLMWorker
from tensorrt_llm.logger import logger
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger.set_level("debug")
WorkerId = str
def parse_args(service_name, prefix) -> Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
# TODO: Read block size
parser.add_argument(
"--block-size",
type=int,
default=64,
help="KV block size",
)
parser.add_argument(
"--custom-router",
type=bool,
default=False,
help="Whether to use custom router or not",
)
config = ServiceConfig.get_instance()
config_args = config.as_args(service_name, prefix=prefix)
args = parser.parse_args(config_args)
return args
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Router:
worker = depends(TensorRTLLMWorker)
def __init__(self):
logger.info("Initializing KV router.")
class_name = self.__class__.__name__
self.args = parse_args(class_name, "")
@async_on_start
async def async_init(self):
self.runtime = dynamo_context["runtime"]
self.workers_client = (
await self.runtime.namespace("dynamo")
.component("TensorRTLLMWorker")
.endpoint("generate")
.client()
)
while len(self.workers_client.endpoint_ids()) < self.args.min_workers:
            # TODO: replace print with logger.info
print(
f"Waiting for more workers to be ready.\n"
f" Current: {len(self.workers_client.endpoint_ids())},"
f" Required: {self.args.min_workers}"
)
await asyncio.sleep(2)
kv_listener = self.runtime.namespace("dynamo").component("TensorRTLLMWorker")
await kv_listener.create_service()
self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
print("KV Router initialized")
def _cost_function(
self,
scores: OverlapScores | None,
metrics: AggregatedMetrics | None,
token_length: int,
):
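        # Returns "" when no worker has a usable signal; otherwise a
        # (best_worker_id, prefix-cache score) tuple.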
worker_scores = {}
if scores:
for worker_id, score in scores.scores.items():
                # score is the number of matching blocks; multiply by block_size
                # to convert to tokens and normalize by token_length.
                # The larger the prefix-cache hit, the better.
worker_scores[worker_id] = (
score * self.indexer.block_size() / token_length
)
logger.debug(f"Worker scores: {worker_scores}")
worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0
if metrics:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
if hasattr(endpoint, "gpu_cache_usage_perc")
else 0.0,
"num_requests_waiting": endpoint.num_requests_waiting
if hasattr(endpoint, "num_requests_waiting")
else 0.0,
"gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
else 0.0,
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
)
logger.debug(f"Worker metrics: {worker_metrics}")
        # Get all worker IDs from the client, since scores/metrics may not cover
        # every worker and all workers should enter the logit calculation.
worker_ids = self.workers_client.endpoint_ids()
worker_logits = {}
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(
worker_id,
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
if max_waiting > 0
else 0.0
)
            # One metric rewards cache hits; two metrics penalize an
            # overloaded KV cache and request queuing.
worker_logits[worker_id] = (
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
logger.debug(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
)
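            # Illustrative values (not from a real run): score=0.5, cache
            # usage=0.3, normalized waiting=0.2 -> logit = 2*0.5 - 0.3 - 0.2 = 0.5.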
        if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
            return ""
        # Select the worker with the highest logit, breaking ties randomly
        max_logit = max(worker_logits.values())
        best_workers = [
            wid for wid, logit in worker_logits.items() if logit == max_logit
        ]
        best_worker_id = random.choice(best_workers)
# Log the metrics for the selected worker
if best_worker_id:
logger.debug(
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
logger.debug(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
metrics_dict = worker_metrics.get(best_worker_id, {})
logger.debug(
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
)
logger.debug(
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
)
logger.debug(
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
)
return best_worker_id, worker_scores.get(best_worker_id, 0.0)
@dynamo_endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
        if self.indexer is None or self.metrics_aggregator is None:
            yield "_0.0"
            return
lora_id = 0
worker_id = ""
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics()
schedule_result = self._cost_function(scores, metrics, token_length)
except Exception:
schedule_result = ""
logger.warning(f"Error during worker selection: {traceback.format_exc()}")
if schedule_result == "":
worker_id = ""
prefix_hit_rate = 0.0
else:
worker_id, prefix_hit_rate = schedule_result
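        # Responses are streamed as "<worker_id>_<prefix_hit_rate>"; an empty
        # worker_id tells the caller to fall back to its default routing.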
yield f"{worker_id}_{prefix_hit_rate}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
from common.base_engine import ChatProcessorMixin
from common.parser import parse_tensorrt_llm_args
from common.protocol import DynamoTRTLLMChatCompletionRequest
from common.utils import RequestType, ServerType
from components.agg_worker import TensorRTLLMWorker
from components.kv_router import Router
from tensorrt_llm.logger import logger
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
logger.set_level("debug")
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ChatProcessorMixin):
worker = depends(TensorRTLLMWorker)
router = depends(Router)
def __init__(
self,
):
class_name = self.__class__.__name__
config = ServiceConfig.get_instance()
config_args = config.as_args(class_name, prefix="")
self.args, self.engine_config = parse_tensorrt_llm_args(config_args)
self.router_mode = self.args.router
super().__init__(self.engine_config)
self.min_workers = 1
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = TensorRTLLMWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
while len(self.worker_client.endpoint_ids()) < self.min_workers:
print(
f"Waiting for workers to be ready.\n"
f" Current: {len(self.worker_client.endpoint_ids())},"
f" Required: {self.min_workers}"
)
await asyncio.sleep(2)
async def _generate(self, raw_request, request_type: RequestType):
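        # Disable special-token insertion and stripping so tokens round-trip
        # unchanged through pre- and post-processing.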
raw_request.skip_special_tokens = False
raw_request.add_special_tokens = False
raw_request.spaces_between_special_tokens = False
logger.debug(f"[preprocessor] Received request: {raw_request}")
if request_type == RequestType.CHAT:
preprocessed_request = await self.chat_processor.preprocess(raw_request)
else:
preprocessed_request = await self.completions_processor.preprocess(
raw_request
)
worker_id = ""
if self.router_mode == "kv":
async for route_response in self.router.generate(
preprocessed_request.tokens.model_dump_json()
):
worker_id, prefix_hit_rate = route_response.split("_")
prefix_hit_rate = float(prefix_hit_rate)
logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
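                # Only the first routing decision is needed; stop consuming the stream.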
break
if worker_id == "":
if self.args.router == "round-robin":
engine_generator = await self.worker_client.round_robin(
preprocessed_request.model_dump_json()
)
else:
# fallback to random
engine_generator = await self.worker_client.random(
preprocessed_request.model_dump_json()
)
else:
engine_generator = await self.worker_client.direct(
preprocessed_request.model_dump_json(), int(worker_id)
)
if request_type == RequestType.CHAT:
async for response in self.chat_processor.postprocess(
engine_generator,
raw_request,
preprocessed_request.conversation,
ServerType.GEN,
):
logger.debug(f"[preprocessor] Response: {response}")
yield json.loads(response)
else:
async for response in self.completions_processor.postprocess(
engine_generator, raw_request
):
logger.debug(f"[preprocessor] Response: {response}")
yield json.loads(response)
@dynamo_endpoint(name="chat/completions")
async def generate_chat(self, raw_request: DynamoTRTLLMChatCompletionRequest):
async for response in self._generate(raw_request, RequestType.CHAT):
yield response
# @dynamo_endpoint()
# async def completions(self, raw_request):
# async for response in self._generate(raw_request, RequestType.COMPLETION):
# yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
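# Aggregated serving config: the Processor routes requests round-robin across workers.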
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
block-size: 64
router: round-robin
TensorRTLLMWorker:
engine_args: "configs/llm_api_config.yaml"
router: random
ServiceArgs:
workers: 1
resources:
gpu: 1
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
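# Aggregated serving config with KV-cache-aware routing via the Router component.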
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000
Processor:
engine_args: "configs/llm_api_config.yaml"
block-size: 64
router: kv
TensorRTLLMWorker:
engine_args: "configs/llm_api_config.yaml"
router: kv
ServiceArgs:
workers: 1
resources:
gpu: 1
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path: null
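# (assumption) when model_path is null, the worker resolves the model by
# model_name, e.g. from the Hugging Face hub.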
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 10240
max_batch_size: 16
trust_remote_code: true
backend: pytorch
kv_cache_config:
free_gpu_memory_fraction: 0.95
# Uncomment to enable kv cache event collection
#event_buffer_max_size: 1024
#enable_block_reuse: true
pytorch_backend_config:
enable_overlap_scheduler: false
use_cuda_graph: false
# Uncomment to enable iter perf stats
#enable_iter_perf_stats: true
\ No newline at end of file