"launch/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "3d2928510c77fc6cb29c497c91a493e3d06c0cc1"
Commit 2791b9ea authored by NVShreyas, committed by GitHub

feat: OAI compatible endpoints for TRTLLM (#14)


Co-authored-by: Tanmay Verma <tanmayv@nvidia.com>
Co-authored-by: Tanmay Verma <tanmayv@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Tanmay Verma <tanmay2592@gmail.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent a657ec61
@@ -26,9 +26,11 @@ pre-commit
protobuf==5.27.3
pydantic==2.7.1
pyright
PyYAML
sentencepiece
transformers
tritonclient==2.53.0
types-PyYAML
# TODO: See whether TRT-LLM installs a different version of UCX. Need to revisit and track this dependency.
ucx-py-cu12
uvicorn
@@ -84,20 +84,32 @@ pip install /home/tensorrt_llm-*.whl

Note: NATS and ETCD servers should be running and accessible from the container as described in the [Prerequisites](#prerequisites) section.
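If you just need a quick local setup and have Docker available, one possible way to bring both up is sketched below. This is an assumed convenience snippet, not the project's documented procedure; the images, tags, and flags are illustrative, and the ports match the defaults referenced later in this guide.

```bash
# Assumed quick-start for local NATS (with JetStream) and etcd -- adjust to your environment
docker run -d --name nats -p 4222:4222 nats:latest -js
docker run -d --name etcd -p 2379:2379 quay.io/coreos/etcd:v3.5.9 \
  /usr/local/bin/etcd --listen-client-urls http://0.0.0.0:2379 \
  --advertise-client-urls http://0.0.0.0:2379
```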
### Monolithic Deployment

#### 1. HTTP Server

Run the HTTP server (with debug-level logging):
```bash
TRD_LOG=DEBUG http &
```
By default the server will run on port 8080.

Add the model to the server:
```bash
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.completions
```
#### 2. Workers

Note: The following commands are tested on machines with H100x8 GPUs.

##### Option 2.1 Single-Node Single-GPU

```bash
# Launch worker
cd /workspace/examples/python_rs/llm/tensorrt_llm
mpirun --allow-run-as-root -n 1 --oversubscribe python3 -m monolith.worker --engine_args llm_api_config.yaml 1>agg_worker.log 2>&1 &
```
Upon successful launch, the output should look similar to:

@@ -113,66 +125,114 @@ Upon successful launch, the output should look similar to:

`nvidia-smi` can be used to check the GPU usage and confirm the model is loaded on a single GPU.
##### Option 2.2 Single-Node Multi-GPU

Update `tensor_parallel_size` in the `llm_api_config.yaml` to load the model with the desired number of GPUs (for example, 4), then relaunch the worker as in Option 2.1. A sketch of the change is shown below.
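A minimal sketch of the relevant snippet in `llm_api_config.yaml`, assuming `tensor_parallel_size` is a top-level engine argument in that file (other keys omitted; keep your existing configuration):

```yaml
# llm_api_config.yaml (excerpt) -- illustrative sketch only
model_name: TinyLlama/TinyLlama-1.1B-Chat-v1.0
tensor_parallel_size: 4
```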
`nvidia-smi` can be used to check the GPU usage and confirm the model is loaded on 4 GPUs.
##### Option 2.3 Multi-Node Multi-GPU

TODO: Add multi-node multi-GPU example
#### 3. Client

```bash
# Chat Completion
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "messages": [
      {"role": "user", "content": "What is the capital of France?"}
    ]
  }'
```
The output should look similar to:
```json
{
"id": "ab013077-8fb2-433e-bd7d-88133fccd497",
"choices": [
{
"message": {
"role": "assistant",
"content": "The capital of France is Paris."
},
"index": 0,
"finish_reason": "stop"
}
],
"created": 1740617803,
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "chat.completion",
"usage": null,
"system_fingerprint": null
}
```

```bash
# Completion
curl localhost:8080/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "prompt": "The capital of France is",
    "max_tokens": 1,
    "temperature": 0
  }'
```

Output:
```json
{
"id":"cmpl-e0d75aca1bd540399809c9b609eaf010",
"choices":[
{
"text":"Paris",
"index":0,
"finish_reason":"length"
}
],
"created":1741024639,
"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object":"text_completion",
"usage":null
}
```
### Disaggregated Deployment

**Environment**

This is the latest image with tensorrt_llm supporting distributed serving with the PyTorch workflow in the LLM API.

Run the container interactively with the following command:
```bash
./container/run.sh --image IMAGE -it
```
#### 1. HTTP Server

Run the HTTP server (with debug-level logging):
```bash
TRD_LOG=DEBUG http &
```
By default the server will run on port 8080.

Add the model to the server:
```bash
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.completions
```
#### 2. Workers
##### Option 2.1 Single-Node Disaggregated Deployment
**TRTLLM LLMAPI Disaggregated config file**

Define a disaggregated config file similar to the example [single_node_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_config.yaml). The important sections are the model, context_servers and generation_servers; an illustrative sketch follows.
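The sketch below is illustrative only; refer to the shipped `single_node_config.yaml` for the authoritative values. The hostnames, ports, and instance counts here are placeholders chosen to match the WORLD_SIZE discussion below (2 context servers and 1 generation server, each TP=1).

```yaml
# Illustrative disaggregated config sketch -- not the shipped example
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
context_servers:
  num_instances: 2
  tensor_parallel_size: 1
  urls:
    - "localhost:8001"
    - "localhost:8002"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  urls:
    - "localhost:8003"
```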
1. **Launch the servers**

Launch context and generation servers.\
WORLD_SIZE is the total number of workers covering all the servers described in disaggregated configuration.\
@@ -180,32 +240,25 @@ For example, 2 TP2 generation servers are 2 servers but 4 workers/mpi executor.
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
mpirun --allow-run-as-root --oversubscribe -n WORLD_SIZE python3 -m disaggregated.worker --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/single_node_config.yaml 1>disagg_workers.log 2>&1 &
```
If using the provided [single_node_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_config.yaml), WORLD_SIZE should be 3, as it has 2 context servers (TP=1) and 1 generation server (TP=1); a concrete invocation is sketched below.
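For instance, substituting WORLD_SIZE into the launch command from the previous step gives (a sketch; paths as above):

```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
# 2 context servers (TP=1) + 1 generation server (TP=1) => 3 MPI ranks
mpirun --allow-run-as-root --oversubscribe -n 3 python3 -m disaggregated.worker --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/single_node_config.yaml 1>disagg_workers.log 2>&1 &
```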
2. **Launch the router**

```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
python3 -m disaggregated.router 1>router.log 2>&1 &
```
3. **Send Requests**
Follow the instructions in the [Monolithic Deployment](#3-client) section to send requests to the router.
For more details on the disaggregated deployment, please refer to the [TRT-LLM example](#TODO).
### Multi-Node Disaggregated Deployment

To run the disaggregated deployment across multiple nodes, we need to launch the servers using MPI, pass the correct NATS and etcd endpoints to each server and update the LLMAPI disaggregated config file to use the correct endpoints.
@@ -251,9 +304,9 @@ export NATS_SERVER="nats://node1:4222"
export ETCD_ENDPOINTS="http://node1:2379,http://node2:2379"
```
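To point the disaggregated config at the worker nodes, the `urls` entries in the multi-node config would reference the node hostnames rather than localhost. A hypothetical excerpt (hostnames and ports are illustrative, not the shipped `multi_node_config.yaml`):

```yaml
# multi_node_config.yaml (excerpt) -- hostnames/ports are illustrative
context_servers:
  urls:
    - "node1:8001"
generation_servers:
  urls:
    - "node2:8001"
```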
3. Launch the workers from node1 or login node. WORLD_SIZE is similar to single node deployment.
```bash
srun --mpi pmix -N NUM_NODES --ntasks WORLD_SIZE --ntasks-per-node=WORLD_SIZE --no-container-mount-home --overlap --container-image IMAGE --output batch_%x_%j.log --err batch_%x_%j.err --container-mounts PATH_TO_TRITON_DISTRIBUTED:/workspace --container-env=NATS_SERVER,ETCD_ENDPOINTS bash -c 'cd /workspace/examples/python_rs/llm/tensorrt_llm && python3 -m disaggregated.worker --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/multi_node_config.yaml' &
```
Once the workers are launched, you should see output similar to the following in the worker logs.

@@ -270,25 +323,8 @@ Once the workers are launched, you should see the output similar to the followin
4. Launch the router from node1 or login node.
```bash
srun --mpi pmix -N 1 --ntasks 1 --ntasks-per-node=1 --overlap --container-image IMAGE --output batch_router_%x_%j.log --err batch_router_%x_%j.err --container-mounts PATH_TO_TRITON_DISTRIBUTED:/workspace --container-env=NATS_SERVER,ETCD_ENDPOINTS bash -c 'cd /workspace/examples/python_rs/llm/tensorrt_llm && python3 -m disaggregated.router' &
```
5. Send requests to the router.

The router will connect to the OAI compatible server. You can send requests to the router using the standard OAI format as shown in previous sections, for example:
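A sketch, assuming the HTTP server from step 1 is running on node1 (adjust the hostname and port to your setup):

```bash
curl node1:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "messages": [
      {"role": "user", "content": "What is the capital of France?"}
    ]
  }'
```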
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import Any, Optional
from common.parser import LLMAPIConfig
from common.processor import ChatProcessor, CompletionsProcessor
from tensorrt_llm._torch import LLM
from tensorrt_llm.logger import logger
from transformers import AutoTokenizer
class BaseTensorrtLLMEngine:
def __init__(self, engine_config: LLMAPIConfig):
self._engine_config = engine_config
logger.info(f"Using LLM API config: {self._engine_config}")
# model name for chat processor
self._model_name = self._engine_config.model_name
logger.info(f"Set model name: {self._model_name}")
# model for LLMAPI input
self._model = self._model_name
if self._engine_config.model_path:
self._model = self._engine_config.model_path
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_path
)
logger.info(f"Using model from path: {self._engine_config.model_path}")
else:
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_name
)
self._init_engine()
if self._engine_config.extra_args.get("tokenizer", None):
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.extra_args.get("tokenizer", None)
)
self.chat_processor = ChatProcessor(self._model_name, self._tokenizer)
self.completions_processor = CompletionsProcessor(self._model_name)
def _init_engine(self):
logger.info("Initializing engine")
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine: Optional[Any] = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
)
self._event_thread.start()
with self._llm_engine_start_cv:
while self._llm_engine is None:
self._llm_engine_start_cv.wait()
# The 'threading.Thread()' will not raise the exception here should the engine
# fail to start, so the exception is passed back via the engine variable.
if isinstance(self._llm_engine, Exception):
e = self._llm_engine
logger.error(f"Failed to start engine: {e}")
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
raise e
async def _run_llm_engine(self):
# Counter to keep track of ongoing request counts.
self._ongoing_request_count = 0
@asynccontextmanager
async def async_llm_wrapper():
# Create LLM in a thread to avoid blocking
loop = asyncio.get_running_loop()
try:
llm = await loop.run_in_executor(
None,
lambda: LLM(model=self._model, **self._engine_config.to_dict()),
)
yield llm
finally:
if "llm" in locals():
# Run shutdown in a thread to avoid blocking
await loop.run_in_executor(None, llm.shutdown)
try:
async with async_llm_wrapper() as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
# Signal the engine is started and make it visible to other threads.
with self._llm_engine_start_cv:
self._llm_engine = engine
self._llm_engine_start_cv.notify_all()
logger.info("Engine loaded and ready to serve...")
# Wait for the engine shutdown signal.
await self._llm_engine_shutdown_event.wait()
# Wait for the ongoing requests to complete.
while self._ongoing_request_count > 0:
logger.info(
"Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(1)
# Cancel all tasks in the event loop.
for task in asyncio.all_tasks(loop=self._event_loop):
if task is not asyncio.current_task():
task.cancel()
except Exception as e:
# Signal and pass the exception back via the engine variable if the engine
# failed to start. If the engine has started, re-raise the exception.
with self._llm_engine_start_cv:
if self._llm_engine is None:
self._llm_engine = e
self._llm_engine_start_cv.notify_all()
return
raise e
self._llm_engine = None
logger.info("Shutdown complete")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Dict, List, TypedDict, Union
from common.protocol import DisaggChatCompletionStreamResponse
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatMessage,
DeltaMessage,
FunctionCall,
ToolCall,
UsageInfo,
)
from transformers import AutoTokenizer
class ConversationMessage(TypedDict):
role: str
content: str
def parse_chat_message_content(
message: ChatCompletionMessageParam,
) -> Union[ConversationMessage, List[ConversationMessage], List[None]]:
role = message["role"]
content = message.get("content")
if content is None:
return []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)]
texts: List[str] = []
for part in content:
part_type = part["type"]
if part_type == "text":
text = part["text"] # type: ignore
texts.append(text)
else:
raise NotImplementedError(f"{part_type} is not supported")
text_prompt = "\n".join(texts)
return [ConversationMessage(role=role, content=text_prompt)]
class ChatProcessor:
def __init__(
self, model: str, tokenizer: AutoTokenizer, request: ChatCompletionRequest
):
self.model = model
self.tokenizer = tokenizer
self.request = request
self.num_choices = 1 if self.request.n is None else self.request.n
self.finish_reason_sent = [False] * self.num_choices
self.role = self._get_role(self.request)
def _get_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
role = "assistant"
else:
role = request.messages[-1]["role"]
return role
def _stream_usage_info(
self, request: ChatCompletionRequest, prompt_tokens: int, completion_tokens: int
):
if (
request.stream_options
and request.stream_options.include_usage
and request.stream_options.continuous_usage_stats
):
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
usage = None
return usage
def _create_logprobs(
self, token_ids: List[int], logprobs: List[float]
) -> ChatCompletionLogProbs:
assert len(token_ids) == len(
logprobs
), "token_ids and logprobs have different lengths"
content: List[ChatCompletionLogProbsContent] = []
for token_id, logprob in zip(token_ids, logprobs):
token = self.tokenizer.decode(token_id)
# returning multiple logprobs is not supported
first_logprob = ChatCompletionLogProbsContent(
token=token,
# NOTE: min logprob -9999.0 for probabilities extremely close to 0
logprob=max(logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")),
)
content.append(first_logprob)
chat_logprobs = ChatCompletionLogProbs(content=content)
return chat_logprobs
def get_chat_stream_response(
self,
request_id: str,
res: RequestOutput,
first_iteration: bool,
) -> DisaggChatCompletionStreamResponse:
def get_first_chat(
num_tokens: int, role: str | None = None, content: str | None = None
):
for i in range(self.num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role, content=content),
finish_reason=None,
)
chunk = DisaggChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice_data],
model=self.model,
)
chunk.usage = self._stream_usage_info(
self.request, num_tokens, completion_tokens=0
)
return chunk
prompt_tokens = len(res.prompt_token_ids)
if first_iteration:
return get_first_chat(prompt_tokens, role=self.role)
for output in res.outputs:
i = output.index
if self.finish_reason_sent[i]:
continue
delta_text = output.text_diff
if (
self.request.tool_choice
and type(self.request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=self.request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text)
choice = ChatCompletionResponseStreamChoice(
index=i, delta=delta_message, finish_reason=None
)
if self.request.logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = self._create_logprobs(token_ids, logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
self.finish_reason_sent[i] = True
chunk = DisaggChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice],
model=self.model,
)
chunk.usage = self._stream_usage_info(
self.request, prompt_tokens, output.length
)
return chunk
def create_final_stream_response(
self,
request_id: str,
final_result: RequestOutput,
) -> DisaggChatCompletionStreamResponse:
prompt_tokens = len(final_result.prompt_token_ids)
completion_tokens = sum(output.length for output in final_result.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = DisaggChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion",
choices=[],
model=self.model,
usage=final_usage,
)
return final_usage_chunk
async def create_chat_response(
self,
request: ChatCompletionRequest,
conversation: List[Dict[str, Any]],
model: str,
promise: RequestOutput,
) -> ChatCompletionResponse:
await promise.aresult()
choices: List[ChatCompletionResponseChoice] = []
role = self._get_role(request)
for output in promise.outputs:
if request.tool_choice and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam
):
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text,
)
)
],
)
else:
message = ChatMessage(role=role, content=output.text)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
if request.logprobs:
choice.logprobs = self._create_logprobs(
output.token_ids, output.logprobs
)
choices.append(choice)
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(promise.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in promise.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
model=model,
choices=choices,
usage=usage,
)
return response
@@ -14,75 +14,93 @@
# limitations under the License.
import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Tuple

import yaml
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig


@dataclass
class LLMAPIConfig:
    def __init__(
        self,
        model_name: str,
        model_path: str | None = None,
        pytorch_backend_config: PyTorchConfig | None = None,
        kv_cache_config: KvCacheConfig | None = None,
        **kwargs,
    ):
        self.model_name = model_name
        self.model_path = model_path
        self.pytorch_backend_config = pytorch_backend_config
        self.kv_cache_config = kv_cache_config
        self.extra_args = kwargs

    def to_dict(self) -> Dict[str, Any]:
        data = {
            "pytorch_backend_config": self.pytorch_backend_config,
            "kv_cache_config": self.kv_cache_config,
        }
        if self.extra_args:
            data.update(self.extra_args)
        return data

    def update_sub_configs(self, other_config: Dict[str, Any]):
        if "pytorch_backend_config" in other_config:
            self.pytorch_backend_config = PyTorchConfig(
                **other_config["pytorch_backend_config"]
            )
            self.extra_args.pop("pytorch_backend_config", None)

        if "kv_cache_config" in other_config:
            self.kv_cache_config = KvCacheConfig(**other_config["kv_cache_config"])
            self.extra_args.pop("kv_cache_config", None)


def _get_llm_args(engine_config):
    # Only do model validation checks and leave other checks to LLMAPI
    if "model_name" not in engine_config:
        raise ValueError("Model name is required in the TRT-LLM engine config.")

    if engine_config.get("model_path", ""):
        if os.path.exists(engine_config.get("model_path", "")):
            engine_config["model_path"] = Path(engine_config["model_path"])
        else:
            raise ValueError(f"Model path {engine_config['model_path']} does not exist")

    model_name = engine_config["model_name"]
    model_path = engine_config.get("model_path", None)
    engine_config.pop("model_name")
    engine_config.pop("model_path", None)

    # Store all other args as kwargs
    llm_api_config = LLMAPIConfig(
        model_name=model_name,
        model_path=model_path,
        **engine_config,
    )
    # Parse supported sub configs and remove from kwargs
    llm_api_config.update_sub_configs(engine_config)

    return llm_api_config


def _init_engine_args(engine_args_filepath):
    """Initialize engine arguments from config file."""
    if not os.path.isfile(engine_args_filepath):
        raise ValueError(
            "YAML file containing TRT-LLM engine args must be provided when launching the worker."
        )

    try:
        with open(engine_args_filepath) as file:
            trtllm_engine_config = yaml.safe_load(file)
    except yaml.YAMLError as e:
        raise RuntimeError(f"Failed to parse engine config: {e}")

    return _get_llm_args(trtllm_engine_config)
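Based on the keys this parser reads, a minimal `llm_api_config.yaml` might look like the sketch below. This is illustrative only; `model_path`, the backend options, and their values are assumptions, not the shipped example file.

```yaml
# Hypothetical llm_api_config.yaml -- keys follow _get_llm_args/LLMAPIConfig above
model_name: TinyLlama/TinyLlama-1.1B-Chat-v1.0
# model_path: /path/to/local/checkpoint   # optional; must exist if set
tensor_parallel_size: 1                    # forwarded to the LLM as an extra argument
pytorch_backend_config:
  enable_overlap_scheduler: true
  use_cuda_graph: false
kv_cache_config:
  free_gpu_memory_fraction: 0.9
```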
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import time
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
List,
Tuple,
TypedDict,
Union,
)
from common.protocol import (
DisaggCompletionResponseStreamChoice,
DisaggCompletionStreamResponse,
DisaggregatedTypeConverter,
)
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
DeltaMessage,
FunctionCall,
ToolCall,
UsageInfo,
)
from transformers import AutoTokenizer
logger.set_level("debug")
class ConversationMessage(TypedDict):
role: str
content: str
def parse_chat_message_content(
message: ChatCompletionMessageParam,
) -> Union[ConversationMessage, List[ConversationMessage], List[None]]:
role = message["role"]
content = message.get("content")
if content is None:
return []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)]
texts: List[str] = []
for part in content:
part_type = part["type"]
if part_type == "text":
text = part["text"] # type: ignore
texts.append(text)
else:
raise NotImplementedError(f"{part_type} is not supported")
text_prompt = "\n".join(texts)
return [ConversationMessage(role=role, content=text_prompt)]
class ChatProcessor:
def __init__(self, model: str, tokenizer: AutoTokenizer):
self.model = model
self.tokenizer = tokenizer
def _get_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
role = "assistant"
else:
role = request.messages[-1]["role"]
return role
def _stream_usage_info(
self, request: ChatCompletionRequest, prompt_tokens: int, completion_tokens: int
):
if (
request.stream_options
and request.stream_options.include_usage
and request.stream_options.continuous_usage_stats
):
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
usage = None
return usage
def _create_logprobs(
self, token_ids: List[int], logprobs: List[float]
) -> ChatCompletionLogProbs:
assert len(token_ids) == len(
logprobs
), "token_ids and logprobs have different lengths"
content: List[ChatCompletionLogProbsContent] = []
for token_id, logprob in zip(token_ids, logprobs):
token = self.tokenizer.decode(token_id)
# returning multiple logprobs is not supported
first_logprob = ChatCompletionLogProbsContent(
token=token,
logprob=max(logprob, -9999.0),
bytes=list(token.encode("utf-8", errors="replace")),
)
content.append(first_logprob)
chat_logprobs = ChatCompletionLogProbs(content=content)
return chat_logprobs
async def _chat_stream_generator(
self,
request: ChatCompletionRequest,
request_id: str,
conversation: List[Dict[str, Any]],
promise: RequestOutput,
) -> AsyncGenerator[str, None]:
first_iteration = True
num_choices = 1 if request.n is None else request.n
finish_reason_sent = [False] * num_choices
role = self._get_role(request)
def yield_first_chat(
num_tokens: int, role: str | None = None, content: str | None = None
):
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role, content=content),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice_data],
model=self.model,
)
chunk.usage = self._stream_usage_info(request, num_tokens, 0)
data = chunk.model_dump_json(exclude_unset=True)
return data
async for res in promise:
prompt_tokens = len(res.prompt_token_ids)
if first_iteration:
yield f"data: {yield_first_chat(prompt_tokens, role=role)} \n\n"
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
yield f"data: {yield_first_chat(prompt_tokens, content=last_msg_content)}\n\n"
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_text = output.text_diff
if (
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
delta_message = DeltaMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text,
)
)
]
)
else:
delta_message = DeltaMessage(content=delta_text)
choice = ChatCompletionResponseStreamChoice(
index=i, delta=delta_message, finish_reason=None
)
if request.logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = self._create_logprobs(token_ids, logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion.chunk",
choices=[choice],
model=self.model,
)
chunk.usage = self._stream_usage_info(
request, prompt_tokens, output.length
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
if request.stream_options and request.stream_options.include_usage:
completion_tokens = sum(output.length for output in promise.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
created=int(time.time()),
object="chat.completion",
choices=[],
model=self.model,
usage=final_usage,
)
final_usage_data = final_usage_chunk.model_dump_json()
yield f"data: {final_usage_data}\n\n"
yield "data: [DONE]\n\n"
async def stream_response(
self,
request: ChatCompletionRequest,
request_id: str,
conversation: List[Dict[str, Any]],
promise: RequestOutput,
) -> AsyncGenerator[str, None]:
assert request.stream, "Only stream is supported"
async for raw_response in self._chat_stream_generator(
request, request_id, conversation, promise
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
async def create_chat_response(
self,
request: ChatCompletionRequest,
conversation: List[Dict[str, Any]],
model: str,
promise: RequestOutput,
) -> ChatCompletionResponse:
await promise.aresult()
choices: List[ChatCompletionResponseChoice] = []
role = self._get_role(request)
for output in promise.outputs:
if request.tool_choice and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam
):
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text,
)
)
],
)
else:
message = ChatMessage(role=role, content=output.text)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
if request.logprobs:
choice.logprobs = self._create_logprobs(
output.token_ids, output.logprobs
)
choices.append(choice)
if request.echo:
last_msg_content = ""
if (
conversation
and conversation[-1].get("content")
and conversation[-1].get("role") == role
):
last_msg_content = conversation[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(promise.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in promise.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
model=model,
choices=choices,
usage=usage,
)
return response
def merge_promises(
promises: List[RequestOutput],
) -> AsyncIterator[Tuple[int, RequestOutput]]:
outputs = asyncio.Queue() # type: ignore
finished = [False] * len(promises)
async def producer(i: int, promise: RequestOutput):
async for output in promise:
await outputs.put((i, output))
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, promise)) for i, promise in enumerate(promises)
]
async def consumer():
while not all(finished) or not outputs.empty():
item = await outputs.get()
yield item
await asyncio.gather(*_tasks)
return consumer()
class CompletionsProcessor:
def __init__(self, model: str):
self.model = model
def _post_process(self, request, prompt_idx, num_choices, request_output):
res = []
echoed = [False] * num_choices
num_response_per_request = 1 if request.n is None else request.n
for gen_idx, output in enumerate(request_output.outputs):
response_idx = prompt_idx * num_response_per_request + gen_idx
delta_text = output.text_diff
if request.echo and not echoed[response_idx]:
delta_text = request.prompt + delta_text
echoed[response_idx] = True
choice = DisaggCompletionResponseStreamChoice(
index=response_idx,
text=delta_text,
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
)
if output.disaggregated_params is not None:
choice.disaggregated_params = (
DisaggregatedTypeConverter.to_oai_disaggregated_params(
output.disaggregated_params
)
)
chunk = DisaggCompletionStreamResponse(
model=self.model,
choices=[choice],
)
res.append(chunk.model_dump_json())
return res
async def create_completion_generator(
self,
request: CompletionRequest,
generator: AsyncIterator[Tuple[int, RequestOutput]],
num_choices: int,
):
async for prompt_idx, request_output in generator:
pp_res = self._post_process(request, prompt_idx, num_choices, request_output)
for _p in pp_res:
yield _p
async def create_completion_response(
self,
request: CompletionRequest,
generator: AsyncIterator[Tuple[int, RequestOutput]],
num_choices: int,
):
choices = [None] * num_choices
num_response_per_request = 1 if request.n is None else request.n
num_prompt_tokens = num_gen_tokens = 0
async for prompt_idx, request_output in generator:
num_prompt_tokens += len(request_output.prompt_token_ids)
for gen_idx, output in enumerate(request_output.outputs):
num_gen_tokens += len(output.token_ids)
output_text = output.text
if request.echo:
output_text = request_output.prompt + output_text
idx = prompt_idx * num_response_per_request + gen_idx
disaggregated_params = CompletionResponseChoice.to_disaggregated_params(
output.disaggregated_params
)
choice = CompletionResponseChoice(
index=idx,
text=output_text,
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
disaggregated_params=disaggregated_params,
)
choices[idx] = choice
usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_gen_tokens,
total_tokens=num_gen_tokens + num_prompt_tokens,
)
response = CompletionResponse(
model=self.model,
choices=choices,
usage=usage_info,
)
return response
@@ -13,23 +13,93 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import time
import uuid
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionResponseStreamChoice,
DisaggregatedParams,
UsageInfo,
)
class Request(BaseModel):
prompt: str
sampling_params: dict
streaming: bool
class DisaggregatedTypeConverter:
@staticmethod
def to_llm_disaggregated_params(
disaggregated_params: DisaggregatedParams,
) -> LlmDisaggregatedParams:
if disaggregated_params is None:
return None
else:
opaque_state = (
base64.b64decode(disaggregated_params.encoded_opaque_state)
if disaggregated_params.encoded_opaque_state is not None
else None
)
return LlmDisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
)
@staticmethod
def to_oai_disaggregated_params(
tllm_disagg_params: LlmDisaggregatedParams,
) -> DisaggregatedParams:
if tllm_disagg_params is None:
return None
else:
encoded_opaque_state = (
base64.b64encode(tllm_disagg_params.opaque_state).decode("utf-8")
if tllm_disagg_params is not None
else None
)
return DisaggregatedParams(
request_type=tllm_disagg_params.request_type,
first_gen_tokens=tllm_disagg_params.first_gen_tokens,
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encoded_opaque_state,
)
# Chat Completions
class DisaggChatCompletionRequest(ChatCompletionRequest):
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
class DisaggChatCompletionStreamResponse(ChatCompletionStreamResponse):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
## Completions


class DisaggCompletionResponseStreamChoice(CompletionResponseStreamChoice):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)


class DisaggCompletionStreamResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str = Field(default_factory=lambda: f"cmpl-{str(uuid.uuid4().hex)}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[DisaggCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
@@ -13,22 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# This will overwrite the llm_api_config.yaml
hostname: localhost
port: 8000
context_servers:
  num_instances: 1
  tensor_parallel_size: 1
  moe_expert_parallel_size: 1
  kv_cache_config:
    free_gpu_memory_fraction: 0.45
  pytorch_backend_config:
    enable_overlap_scheduler: false
    use_cuda_graph: false
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  moe_expert_parallel_size: 1
  kv_cache_config:
    free_gpu_memory_fraction: 0.95
  pytorch_backend_config:
    enable_overlap_scheduler: true
    use_cuda_graph: true
  urls:
    - "localhost:8002"
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import json

import uvloop
from common.protocol import (
DisaggChatCompletionRequest,
DisaggChatCompletionStreamResponse,
DisaggCompletionStreamResponse,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams
from triton_distributed.runtime import (
DistributedRuntime,
@@ -34,56 +32,116 @@ from triton_distributed.runtime import (
triton_worker,
)

logger.set_level("debug")
class Router:
def __init__(
self,
ctx_chat_client,
gen_chat_client,
ctx_completion_client,
gen_completion_client,
):
self.ctx_chat_client = ctx_chat_client
self.gen_chat_client = gen_chat_client
self.ctx_completion_client = ctx_completion_client
self.gen_completion_client = gen_completion_client
logger.info("INITIALIZED ROUTER")

async def _get_ctx_resp(self, request, ctx_client):
logger.debug(f"Received request {request}")

request.max_tokens = 1
request.disaggregated_params = DisaggregatedParams(request_type="context_only")
logger.debug(f"[router] Sending request to context server: {request}")

ctx_resp = [
resp
async for resp in await ctx_client.round_robin(request.model_dump_json())
]
if len(ctx_resp) > 1:
raise ValueError(
"Context server returned more than one response. This is currently not supported in disaggregated server."
)
logger.debug(
f"[router] received response from context server: {ctx_resp[0].data()}"
)
return ctx_resp[0].data()

# TODO (shreyasm): The only reason we cant further combine the two methods below is
# because the disagg params are in different locations.
# Disagg params should be under the choices field in the response object.
# This is the case for completions but not for chat.

@triton_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
async def generate_completion(self, request):
# These settings are needed to satisfy request checks.
request.skip_special_tokens = False
request.add_special_tokens = False
request.spaces_between_special_tokens = False
gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(request, self.ctx_completion_client)
ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate(
ctx_resp_obj.choices[0].disaggregated_params
)
gen_req.disaggregated_params.request_type = "generation_only"
if request.stream:
yield json.loads(
ctx_resp_obj.model_dump_json(
exclude_unset=True, exclude={"disaggregated_params"}
)
)

logger.debug(f"[router] Sending request to generation server: {gen_req}")
async for response in await self.gen_completion_client.round_robin(
gen_req.model_dump_json()
):
gen_resp_obj = DisaggCompletionStreamResponse.model_validate(
response.data()
)
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
@triton_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
async def generate_chat(self, request):
# These settings are needed to satisfy request checks.
request.skip_special_tokens = False
request.add_special_tokens = False
request.spaces_between_special_tokens = False
gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(request, self.ctx_chat_client)
ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate(
ctx_resp_obj.disaggregated_params
)
gen_req.disaggregated_params.request_type = "generation_only"
if request.stream:
yield json.loads(
ctx_resp_obj.model_dump_json(
exclude_unset=True, exclude={"disaggregated_params"}
)
)
logger.debug(f"[router] Sending request to generation server: {gen_req}")
async for response in await self.gen_chat_client.round_robin(
gen_req.model_dump_json()
):
gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate(
response.data()
)
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
@triton_worker()
async def worker(runtime: DistributedRuntime):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
@@ -91,36 +149,42 @@ async def worker(runtime: DistributedRuntime, server_configs: list[CtxGenServerC
component = runtime.namespace("triton-init").component("router")
await component.create_service()
ctx_completion_client = (
await runtime.namespace("triton-init")
.component("tensorrt-llm-ctx")
.endpoint("completions")
.client()
)
gen_completion_client = (
await runtime.namespace("triton-init")
.component("tensorrt-llm-gen")
.endpoint("completions")
.client()
)
ctx_chat_client = (
await runtime.namespace("triton-init")
.component("tensorrt-llm-ctx")
.endpoint("chat/completions")
.client()
)
gen_chat_client = (
await runtime.namespace("triton-init")
.component("tensorrt-llm-gen")
.endpoint("chat/completions")
.client() .client()
)
endpoint = component.endpoint("generate") completions_endpoint = component.endpoint("completions")
await endpoint.serve_endpoint(Router(ctx_client, gen_client).generate) chat_endpoint = component.endpoint("chat/completions")
router = Router(
ctx_chat_client, gen_chat_client, ctx_completion_client, gen_completion_client
)
await asyncio.gather(
completions_endpoint.serve_endpoint(router.generate_completion),
chat_endpoint.serve_endpoint(router.generate_chat),
)
if __name__ == "__main__": if __name__ == "__main__":
uvloop.install() uvloop.install()
asyncio.run(worker())
@@ -15,21 +15,26 @@

import asyncio
import json
import os
import signal

import uvloop
from common.base_engine import BaseTensorrtLLMEngine
from common.disagg_processor import ChatProcessor, parse_chat_message_content
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import merge_promises
from common.protocol import (
DisaggChatCompletionRequest,
DisaggChatCompletionStreamResponse,
DisaggCompletionStreamResponse,
DisaggregatedTypeConverter,
)
from mpi4py.futures import MPICommExecutor
from mpi4py.MPI import COMM_WORLD
from tensorrt_llm._utils import set_mpi_comm
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import MpiCommSession
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
DisaggServerConfig,
@@ -37,177 +42,183 @@ from tensorrt_llm.llmapi.disagg_utils import (
split_world_comm,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)

logger.set_level("debug")


def update_args_from_disagg_config(
engine_config: LLMAPIConfig, server_config: CtxGenServerConfig
):
# Overwrite the LLM API config with the disaggregated config
# Allows for different configs for context and generation servers
engine_config.extra_args.update(**server_config.other_args)
engine_config.update_sub_configs(server_config.other_args)
return engine_config
class TensorrtLLMEngine(BaseTensorrtLLMEngine):
    """
    Request handler for the generate endpoint
    """

    def __init__(
        self,
        engine_config: LLMAPIConfig,
        disagg_config: DisaggServerConfig,
        instance_idx: int,
        sub_comm,
    ):
        self.disagg_config = disagg_config
        self.instance_idx = instance_idx
        self.server_config: CtxGenServerConfig = disagg_config.server_configs[
            instance_idx
        ]
        engine_config = update_args_from_disagg_config(
            engine_config, self.server_config
        )

        # needed for disagg
        self._mpi_session = MpiCommSession(sub_comm, n_workers=sub_comm.Get_size())
        engine_config.extra_args["_mpi_session"] = self._mpi_session
        super().__init__(engine_config)

    @triton_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
    async def generate_chat(self, request):
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")
        logger.debug(f"Received request: {request}")
        chat_processor = ChatProcessor(self._model, self._tokenizer, request)
        self._ongoing_request_count += 1

        try:
            conversation = []
            for message in request.messages:
                conversation.extend(parse_chat_message_content(message))
            tool_dicts = (
                None
                if request.tools is None
                else [tool.model_dump() for tool in request.tools]
            )
            prompt: str = self._tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
            sampling_params = request.to_sampling_params()
            disaggregated_params = (
                DisaggregatedTypeConverter.to_llm_disaggregated_params(
                    request.disaggregated_params
                )
            )

            final_result = None
            async for result in self._llm_engine.generate_async(
                prompt,
                sampling_params,
                streaming=request.stream,
                disaggregated_params=disaggregated_params,
            ):
                final_result = result
                logger.debug(f"Generated result: {result}")
                if self.server_config.type == "ctx":
                    disaggregated_response = chat_processor.get_chat_stream_response(
                        request.id,
                        result,
                        first_iteration=True,
                    )
                    disaggregated_response.disaggregated_params = (
                        DisaggregatedTypeConverter.to_oai_disaggregated_params(
                            result.outputs[0].disaggregated_params
                        )
                    )
                    yield disaggregated_response.model_dump_json()
                else:
                    yield chat_processor.get_chat_stream_response(
                        request.id,
                        result,
                        first_iteration=False,
                    ).model_dump_json(
                        exclude_unset=True, exclude={"disaggregated_params"}
                    )

            if request.stream_options and request.stream_options.include_usage:
                yield chat_processor.create_final_stream_response(
                    request.id,
                    final_result,
                ).model_dump_json(exclude_unset=True, exclude={"disaggregated_params"})
        except CppExecutorError:
            # If internal executor error is raised, shutdown the server
            signal.raise_signal(signal.SIGINT)
        except Exception as e:
            raise RuntimeError("Failed to generate: " + str(e))

        self._ongoing_request_count -= 1

    @triton_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
    async def generate_completions(self, request):
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")
        self._ongoing_request_count += 1
        logger.debug(f"[worker] Received completions request: {request}")

        if not isinstance(request.prompt, str):
            # Check if it's a list and contains integers
            if isinstance(request.prompt, list) and len(request.prompt) == 1:
                request.prompt = request.prompt[0]
            elif not isinstance(request.prompt, list) or not all(
                isinstance(x, int) for x in request.prompt
            ):
                raise ValueError(
                    "Disaggregated server currently only supports single string prompt or list of integers in request"
                )

        sampling_params = request.to_sampling_params()
        llm_disaggregated_params = (
            DisaggregatedTypeConverter.to_llm_disaggregated_params(
                request.disaggregated_params
            )
        )

        # only 1 prompt is supported for now
        promise = self._llm_engine.generate_async(
            request.prompt,
            sampling_params,
            streaming=request.stream,
            disaggregated_params=llm_disaggregated_params,
        )
        generator = merge_promises([promise])
        num_choices = 1 if request.n is None else request.n
        if request.stream:
            response_generator = self.completions_processor.create_completion_generator(
                request, generator, num_choices
            )
            async for response in response_generator:
                yield json.loads(response)
        else:
            raise RuntimeError("Non-streaming is not supported")

        self._ongoing_request_count -= 1
@triton_worker()
async def worker(
    runtime: DistributedRuntime,
    engine_config: LLMAPIConfig,
    disagg_config: DisaggServerConfig,
    instance_idx: int,
    sub_comm,
...@@ -224,15 +235,18 @@ async def worker(
    )
    await component.create_service()

    completions_endpoint = component.endpoint("completions")
    chat_endpoint = component.endpoint("chat/completions")
    engine = TensorrtLLMEngine(engine_config, disagg_config, instance_idx, sub_comm)
    await asyncio.gather(
        completions_endpoint.serve_endpoint(engine.generate_completions),
        chat_endpoint.serve_endpoint(engine.generate_chat),
    )


if __name__ == "__main__":
    uvloop.install()
    args, engine_config = parse_tensorrt_llm_args()

    if args.llmapi_disaggregated_config is None or not os.path.exists(
        args.llmapi_disaggregated_config
...@@ -254,7 +268,7 @@ if __name__ == "__main__":
    logger.info(f"is_leader: {is_leader}, instance_idx: {instance_idx}")
    if is_leader:
        asyncio.run(worker(engine_config, disagg_config, instance_idx, sub_comm))
    else:
        with MPICommExecutor(sub_comm) as executor:
            if not is_leader and executor is not None:
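The `__main__` block above runs the asyncio worker only on the leader rank; the remaining MPI ranks park inside `MPICommExecutor` and serve tasks submitted to them. For reference, the canonical `MPICommExecutor` pattern from mpi4py looks like the generic, self-contained example below; it is not this worker's actual rank-splitting logic, which goes through `split_world_comm` and per-instance sub-communicators.

```python
# Generic mpi4py pattern (not this repo's code): the root rank gets a real
# executor and can submit work; every other rank blocks inside the context
# and serves those tasks until the root leaves the `with` block.
from mpi4py import MPI
from mpi4py.futures import MPICommExecutor

with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
    if executor is not None:
        # Only rank 0 reaches this branch.
        future = executor.submit(pow, 2, 10)
        print(future.result())  # 1024
```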
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
model_name: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 10240
max_batch_size: 16
trust_remote_code: true
backend: pytorch
kv_cache_config:
  free_gpu_memory_fraction: 0.95
pytorch_backend_config:
  enable_overlap_scheduler: false
  use_cuda_graph: false
\ No newline at end of file
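The nested `kv_cache_config` and `pytorch_backend_config` blocks in this file are treated as sub-configs and merged separately from the flat engine arguments (see `update_args_from_disagg_config` and `update_sub_configs` earlier in the diff). The sketch below shows that split with PyYAML; the helper names and the exact split logic are illustrative assumptions, not the real `LLMAPIConfig` loader.

```python
# Rough sketch, not the actual LLMAPIConfig loader: read a YAML file like the
# one above and separate the nested sub-configs from the flat engine args.
import yaml

SUB_CONFIG_KEYS = {"kv_cache_config", "pytorch_backend_config"}

config_path = "path/to/config.yaml"  # wherever the YAML above is saved

with open(config_path) as f:
    raw = yaml.safe_load(f)

sub_configs = {k: raw.pop(k) for k in list(raw) if k in SUB_CONFIG_KEYS}
engine_args = raw  # model_name, tensor_parallel_size, max_batch_size, ...

print(engine_args["model_name"], sub_configs.get("kv_cache_config"))
```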
{
"copyright": [
"SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.",
"SPDX-License-Identifier: Apache-2.0",
"Licensed under the Apache License, Version 2.0 (the \"License\");",
"you may not use this file except in compliance with the License.",
"You may obtain a copy of the License at",
"http://www.apache.org/licenses/LICENSE-2.0",
"Unless required by applicable law or agreed to in writing, software",
"distributed under the License is distributed on an \"AS IS\" BASIS,",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"See the License for the specific language governing permissions and",
"limitations under the License."
],
"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"tokenizer": null,
"tokenizer_model": null,
"skip_tokenizer_init": null,
"trust_remote_code": null,
"tensor_parallel_size": null,
"dtype": null,
"revision": null,
"tokenizer_revision": null,
"speculative_model": null,
"enable_chunked_prefill": null,
"use_cuda_graph": null,
"cuda_graph_batch_sizes": null,
"cuda_graph_max_batch_size": null,
"cuda_graph_padding_enabled": null,
"enable_overlap_scheduler": null,
"kv_cache_dtype": null,
"torch_compile_enabled": null,
"torch_compile_fullgraph": null,
"torch_compile_inductor_enabled": null
}
...@@ -15,17 +15,22 @@
import asyncio
import json
import signal
import uuid

import uvloop
from common.base_engine import BaseTensorrtLLMEngine
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import merge_promises, parse_chat_message_content
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionStreamResponse,
    CompletionRequest,
    CompletionStreamResponse,
)
from triton_distributed.runtime import (
    DistributedRuntime,
...@@ -33,127 +38,114 @@ from triton_distributed.runtime import (
    triton_worker,
)

logger.set_level("debug")
class TensorrtLLMEngine(BaseTensorrtLLMEngine):
    """
    Request handler for the generate endpoint
    """

    def __init__(self, engine_config: LLMAPIConfig):
        super().__init__(engine_config)

    @triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
    async def generate_chat(self, request):
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")
        logger.debug(f"Received chat request: {request}")
        request_id = str(uuid.uuid4())
        self._ongoing_request_count += 1

        try:
            conversation = []
            for message in request.messages:
                conversation.extend(parse_chat_message_content(message))
            tool_dicts = (
                None
                if request.tools is None
                else [tool.model_dump() for tool in request.tools]
            )
            prompt: str = self._tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
            sampling_params = request.to_sampling_params()
            promise = self._llm_engine.generate_async(
                prompt,
                sampling_params,
                streaming=request.stream,
            )
            # NOTE: somehow stream and non-stream is working with the same path
            response_generator = self.chat_processor.stream_response(
                request, request_id, conversation, promise
            )
            async for response in response_generator:
                yield response
            self._ongoing_request_count -= 1
        except CppExecutorError:
            # If internal executor error is raised, shutdown the server
            signal.raise_signal(signal.SIGINT)
        except Exception as e:
            raise RuntimeError("Failed to generate: " + str(e))

    @triton_endpoint(CompletionRequest, CompletionStreamResponse)
    async def generate_completion(self, request):
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")
        self._ongoing_request_count += 1
        logger.debug(f"Received completion request: {request}")

        if isinstance(request.prompt, str) or (
            isinstance(request.prompt, list) and isinstance(request.prompt[0], int)
        ):
            prompts = [request.prompt]
        else:
            prompts = request.prompt

        promises = []
        sampling_params = request.to_sampling_params()
        try:
            for prompt in prompts:
                promise = self._llm_engine.generate_async(
                    prompt,
                    sampling_params,
                    streaming=request.stream,
                )
                promises.append(promise)

            generator = merge_promises(promises)
            num_choices = (
                len(prompts) if request.n is None else len(prompts) * request.n
            )
            # NOTE: always send `stream: true` to the worker, and decide whether to aggregate or not before sending the response back to client.
            response_generator = self.completions_processor.create_completion_generator(
                request, generator, num_choices
            )
            async for response in response_generator:
                yield json.loads(response)

            self._ongoing_request_count -= 1
        except CppExecutorError:
            # If internal executor error is raised, shutdown the server
            signal.raise_signal(signal.SIGINT)
        except Exception as e:
            raise RuntimeError("Failed to generate: " + str(e))
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_config: LLMAPIConfig):
    """
    Instantiate a `backend` component and serve the `generate` endpoint
    A `Component` can serve multiple endpoints
...@@ -161,11 +153,18 @@ async def worker(
    component = runtime.namespace("triton-init").component("tensorrt-llm")
    await component.create_service()

    completions_endpoint = component.endpoint("completions")
    chat_completions_endpoint = component.endpoint("chat/completions")
    engine = TensorrtLLMEngine(engine_config)
    await asyncio.gather(
        completions_endpoint.serve_endpoint(engine.generate_completion),
        chat_completions_endpoint.serve_endpoint(engine.generate_chat),
    )


if __name__ == "__main__":
    uvloop.install()
    args, engine_config = parse_tensorrt_llm_args()
    asyncio.run(worker(engine_config))
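With a worker serving `completions` and `chat/completions`, requests reach it through the OpenAI-compatible HTTP frontend. The snippet below is a minimal client-side smoke test; the host, port, URL path, and response shape are assumptions following standard OpenAI conventions, so adjust them to the actual deployment.

```python
# Minimal smoke test against the OpenAI-compatible frontend.
# Host/port and the /v1/chat/completions path are assumptions, not taken
# from this change; adjust them to your deployment.
import requests

payload = {
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_tokens": 32,
    "stream": False,
}

resp = requests.post(
    "http://localhost:8080/v1/chat/completions", json=payload, timeout=60
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```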