Commit deb6c7e8 authored by ishandhanani, committed by GitHub

feat(llm): adding initial TRTLLM disaggregation support


Co-authored-by: nnshah1 <neelays@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 425be8ad
......@@ -20,7 +20,10 @@
**/.cache/*
**/*onnx*
# Engine must be allowed because code contains triton_distributed_engine.py
#**/*engine*
**/*tensorrtllm_engines*
**/*tensorrtllm_models*
**/*tensorrtllm_checkpoints*
**/*hf_downloads*
**/*pytorch_model*
**/*.pth*
**/*.pt
......
......@@ -96,7 +96,14 @@ of simple workers to load balance requests from a local work queue.
## LLM
[LLM](./examples/llm/vllm)
[TENSORRTLLM](./examples/llm/tensorrtllm)
An intermediate example expanding further on the concepts introduced
in the Hello World example. In this example, we demonstrate
[Disaggregated Serving](https://arxiv.org/abs/2401.09670)
as an application of the components defined in Triton Distributed.
[VLLM](./examples/llm/vllm)
An intermediate example expanding further on the concepts introduced
in the Hello World example. In this example, we demonstrate
......
......@@ -97,6 +97,7 @@ ARG TENSORRTLLM_FRAMEWORK
ENV FRAMEWORK_LD_LIBRARY_PATH=${TENSORRTLLM_FRAMEWORK:+/opt/tritonserver/backends/tensorrtllm/}
ENV LD_LIBRARY_PATH=${FRAMEWORK_LD_LIBRARY_PATH}:${LD_LIBRARY_PATH}
ENV TENSORRTLLM_BACKEND_COMMIT=$TENSORRTLLM_BACKEND_COMMIT
ENV TRTLLM_USE_MPI_KVCACHE=${TENSORRTLLM_FRAMEWORK:+"1"}
# TODO set VLLM Version
# ENV VLLM_VERSION
......@@ -153,6 +154,9 @@ RUN /workspace/icp/protos/gen_python.sh
# Sets pythonpath for python modules
ENV PYTHONPATH="${PYTHONPATH}:/workspace/icp/src/python:/workspace/worker/src/python:/workspace/examples:/opt/tritonserver/python/openai/openai_frontend"
# Enable system UCX
ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
# Command and Entrypoint
CMD []
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
......@@ -34,6 +34,7 @@ def main(args):
data_plane_port=args.data_plane_port,
model_name=args.model_name,
tokenizer=args.tokenizer,
backend=args.backend,
)
# Attach TritonLLMEngine as the backbone for inference and model management
......
......@@ -128,6 +128,15 @@ def create_chat_response(
# Extract prompt from detokenized_outputs
cleaned_outputs = []
for detokenized_output in detokenized_outputs:
# FIXME - should be receiving array of strings by this point, not nested arrays
if isinstance(detokenized_output, np.ndarray):
detokenized_output = str(detokenized_output[0])
if not isinstance(detokenized_output, str):
raise RuntimeError(
f"ERROR: detokenized_output is not a string! {type(detokenized_output)=} | {detokenized_output=}"
)
# FIXME: Should this be handled by 'echo' param instead?
if detokenized_output.startswith(prompt):
cleaned_output = detokenized_output[len(prompt) :]
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
import numpy
import numpy as np
from llm.api_server.chat import ChatHandler, generate_sampling_params
from llm.api_server.connector import BaseTriton3Connector, InferenceResponse
from schemas.openai import CreateChatCompletionRequest
LOGGER = logging.getLogger(__name__)
# FIXME: Share request conversion logic where applicable
def generate_sampling_params_vllm(
request: CreateChatCompletionRequest,
non_supported_params: Optional[List[str]] = None,
) -> dict:
"""
Generate sampling params for vLLM from the request.
Args:
request: CreateChatCompletionRequest object.
Returns:
dict: Sampling params for vLLM.
"""
errors_message = ""
if request.logprobs:
errors_message += "The parameter 'logprobs' set to True is not supported. "
if request.tools and request.tools.type != "text":
errors_message += (
f"The parameter 'tools' type {request.tools.type} is not supported. "
)
if errors_message:
raise ValueError(errors_message)
if non_supported_params is None:
non_supported_params = [
"logit_bias",
"top_logprobs",
"tool_choice",
"user",
"service_tier",
]
sampling_params = generate_sampling_params(request, non_supported_params)
# NOTE: vLLM parameters (ex: top_k) not supported until added to schema
return sampling_params
class ChatHandlerTensorrtLLM(ChatHandler):
def __init__(
self, triton_connector: BaseTriton3Connector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, tokenizer)
self._model_name = model_name
def translate_chat_inputs(
self, request: CreateChatCompletionRequest, request_id: str, prompt: str
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""Translate the chat completion request to inference request"""
if self._model_name is not None and self._model_name != request.model:
raise ValueError(
f"Model name mismatch: {self._model_name} != {request.model}"
)
inputs: Dict[str, np.ndarray | Any] = {}
sampling_params = generate_sampling_params_vllm(request)
parameters = {
"sampling_params": sampling_params,
"request_id": request_id,
# "prompt": prompt,
}
inputs["text_input"] = [[prompt]]
inputs["max_tokens"] = numpy.array(
[[sampling_params["max_tokens"]]], dtype=numpy.int32
)
return inputs, parameters
def translate_chat_outputs(
self, response: InferenceResponse, model_name: str
) -> Dict[str, Any]:
"""Translate the inference outputs to chat completion response"""
if "text" in response.parameters:
return {"model_output": [response.parameters["text"]]}
elif "text_output" in response.outputs:
print(response.outputs["text_output"])
return {"model_output": response.outputs["text_output"][0]}
return {}
......@@ -94,4 +94,8 @@ class ChatHandlerVllm(ChatHandler):
self, response: InferenceResponse, model_name: str
) -> Dict[str, Any]:
"""Translate the inference outputs to chat completion response"""
return {"model_output": [response.parameters["text"]]}
if "text" in response.parameters:
return {"model_output": [response.parameters["text"]]}
elif "text_output" in response.outputs:
return {"model_output": response.outputs["text_output"][0]}
return {}
......@@ -89,4 +89,13 @@ def parse_args():
help="Logging level",
)
parser.add_argument(
"--backend",
type=str,
required=False,
default="vllm",
choices=["vllm", "tensorrtllm"],
help="Backendtype",
)
return parser, parser.parse_args()
......@@ -16,7 +16,7 @@
import asyncio
import json
import typing
from typing import Optional
from typing import Any, Optional
import numpy as np
from llm.api_server.connector import (
......@@ -25,9 +25,9 @@ from llm.api_server.connector import (
InferenceResponse,
)
from llm.api_server.remote_connector import RemoteConnector
from tritonserver import DataType
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.remote_tensor import RemoteTensor
class RemoteModelConnector(BaseTriton3Connector):
......@@ -129,30 +129,34 @@ class RemoteModelConnector(BaseTriton3Connector):
if isinstance(value, dict):
request.parameters[key] = "JSON:" + json.dumps(value)
store_inputs_in_request = set()
for k, v in request.inputs.items():
store_inputs_in_request.add(k)
results.append(
self._model.async_infer(
inputs=request.inputs,
parameters=request.parameters,
store_inputs_in_request=store_inputs_in_request,
)
)
for result in asyncio.as_completed(results):
responses = await result
outputs = {}
async for response in responses:
triton_response = response.to_model_infer_response(
self._connector._data_plane
)
outputs = {}
for output in triton_response.outputs:
remote_tensor = RemoteTensor(output, self._connector._data_plane)
for output_name, value in response.outputs.items():
try:
local_tensor = remote_tensor.local_tensor
numpy_tensor = np.from_dlpack(local_tensor)
output_value: Any = None
if value.data_type == DataType.BYTES:
output_value = [value.to_string_array()]
else:
output_value = np.from_dlpack(value)
finally:
# FIXME: This is a workaround for the issue that the remote tensor
# is released after connection is closed.
remote_tensor.__del__()
outputs[output.name] = numpy_tensor
# value.__del__()
pass
outputs[output_name] = output_value
infer_response = InferenceResponse(
outputs=outputs,
final=response.final,
......
......@@ -16,6 +16,7 @@ import asyncio
from typing import AsyncIterator
from engine.engine import LLMEngine
from llm.api_server.chat_tensorrtllm import ChatHandlerTensorrtLLM
from llm.api_server.chat_vllm import ChatHandlerVllm
from llm.api_server.remote_model_connector import RemoteModelConnector
from schemas.openai import (
......@@ -28,6 +29,32 @@ from schemas.openai import (
)
class TritonDistributedTensorrtLLMChatHandler(ChatHandlerTensorrtLLM):
def __init__(
self, triton_connector: RemoteModelConnector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, model_name, tokenizer)
# Request / response format can vary between frontends, so allow override
# of adaptor functions accordingly.
def stream_response_adaptor(self, response_stream):
async def adaptor_stream():
async for response in response_stream():
if isinstance(response, Exception):
raise response
else:
# Already in SSE String format
yield response
return adaptor_stream
def response_adaptor(self, response):
return response
def exception_adaptor(self, exception):
raise exception
class TritonDistributedChatHandler(ChatHandlerVllm):
def __init__(
self, triton_connector: RemoteModelConnector, model_name: str, tokenizer: str
......@@ -62,6 +89,7 @@ class TritonDistributedEngine(LLMEngine):
data_plane_port: int,
model_name: str,
tokenizer: str,
backend: str,
):
self.triton_connector = RemoteModelConnector(
nats_url=nats_url,
......@@ -71,10 +99,15 @@ class TritonDistributedEngine(LLMEngine):
keep_dataplane_endpoints_open=True,
)
# FIXME: Consider supporting multiple or per-model tokenizers
self.request_handler = TritonDistributedChatHandler(
self.triton_connector, model_name, tokenizer
)
if not backend or backend == "vllm":
# FIXME: Consider supporting multiple or per-model tokenizers
self.request_handler = TritonDistributedChatHandler(
self.triton_connector, model_name, tokenizer
)
else:
self.request_handler = TritonDistributedTensorrtLLMChatHandler(
self.triton_connector, model_name, tokenizer
)
async def chat(
self, request: CreateChatCompletionRequest
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Disaggregated Serving with TensorRT-LLM
This example demonstrates **disaggregated serving** [^1] using Triton Distributed together with TensorRT-LLM engines. Disaggregated serving decouples the prefill (prompt encoding) and the decode (token generation) stages of large language model (LLM) inference into separate processes. This separation allows you to independently scale, optimize, and distribute resources for each stage.
In this example, you will deploy
- An **OpenAI-compatible API server** (which receives requests and streams responses).
- One or more **prefill workers** (for encoding the prompt).
- One or more **decode workers** (for generating tokens based on the encoded prompt).
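Conceptually, a request flows API server → prefill (context) worker → decode (generate) worker. A minimal sketch of that flow, modeled on the `DisaggregatedServingOperator` added in this commit (streaming, error handling, and tensor cleanup omitted; the helper names here are placeholders), looks like this:

```python
# Simplified sketch, not the shipped implementation.
# `prefill` and `decode` stand for RemoteOperator("context", ...) and
# RemoteOperator("generate", ...); `preprocess`/`postprocess` stand for the
# tokenizer models loaded by the operator.
async def disaggregated_infer(prefill, decode, preprocess, postprocess, prompt, max_tokens):
    # 1. Tokenize the prompt.
    input_ids, input_lengths = await preprocess(prompt)

    # 2. Prefill: context-only pass that produces the context phase parameters.
    prefill_inputs = {
        "input_ids": input_ids,
        "input_lengths": input_lengths,
        "request_output_len": max_tokens,
    }
    async for prefill_response in await prefill.async_infer(
        inputs=prefill_inputs, parameters={"request_type": "context_only"}
    ):
        # 3. Decode: generation-only pass seeded with the context phase params.
        decode_inputs = dict(prefill_inputs)
        decode_inputs["context_phase_params"] = prefill_response.outputs[
            "context_phase_params"
        ]
        async for decode_response in await decode.async_infer(
            inputs=decode_inputs, parameters={"request_type": "generation_only"}
        ):
            # 4. Detokenize the generated token IDs into text.
            return await postprocess(decode_response.outputs["output_ids"])
```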
## 1. Prerequisites
1. **GPU Availability**
This setup requires at least two GPUs:
- One GPU is typically used by the **prefill** process.
- Another GPU is used by the **decode** process.
In production systems with heavier loads, you will typically allocate more GPUs across multiple prefill and decode workers.
2. **NATS or Another Coordination Service**
Triton Distributed uses NATS by default for coordination and message passing. Make sure your environment has a running NATS service accessible via a valid `nats://<address>:<port>` endpoint. By default, this example assumes the request plane is reachable at `<hostname>:4222` (a minimal standalone launch command is sketched after this list).
3. **HuggingFace**
- You need a HuggingFace account to download the model, and you must set the `HF_TOKEN` environment variable.
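If you want to run the request plane yourself rather than passing `--initialize-request-plane` to the launcher, a standalone command equivalent to the `_launch_nats_server` helper in `deploy/launch_workers.py` (assuming `nats-server` is on `PATH`) would be:

```bash
# JetStream-enabled NATS server on the port used throughout this example
nats-server --jetstream --port 4222 --store_dir /tmp/nats_store
```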
---
## 2. Building the Environment
The example is designed to run in a containerized environment using Triton Distributed, TensorRT-LLM, and associated dependencies. To build the container:
```bash
./container/build.sh --framework tensorrtllm
```
---
## 3. Starting the Deployment
Below is a minimal example of how to start each component of a disaggregated serving setup. The typical sequence is:
1. **Download and build model directories**
2. **Start the Context Worker(s) and Request Plane**
3. **Start the Generate Worker(s)**
4. **Start the API Server** (handles incoming requests and coordinates workers)
All components must be able to connect to the same request plane to coordinate.
### 3.1: Launch Interactive Environment
```bash
./container/run.sh --framework tensorrtllm -it
```
Note: all subsequent commands will be run in the same container for simplicity.

Note: by default this command makes all GPU devices visible. Use the `--gpus` flag to selectively make GPU devices visible.
### 3.2: Build model directories
```bash
export HF_TOKEN=<YOUR TOKEN>
python3 /workspace/examples/llm/tensorrtllm/scripts/prepare_models.py --tp-size 1 --model llama-3.1-8b-instruct --max-num-tokens 8192
```
After this completes, you should see the following layout under `/workspace/examples/llm/tensorrtllm/operators`:
```bash
|-- hf_downloads
| `-- llama-3.1-8b-instruct
| |-- config.json
| |-- generation_config.json
| |-- model-00001-of-00004.safetensors
| |-- model-00002-of-00004.safetensors
| |-- model-00003-of-00004.safetensors
| |-- model-00004-of-00004.safetensors
| |-- model.safetensors.index.json
| |-- original
| | `-- params.json
| |-- special_tokens_map.json
| |-- tokenizer.json
| `-- tokenizer_config.json
|-- tensorrtllm_checkpoints
| `-- llama-3.1-8b-instruct
| `-- NVIDIA_H100_NVL
| `-- TP_1
| |-- config.json
| `-- rank0.safetensors
|-- tensorrtllm_engines
| `-- llama-3.1-8b-instruct
| `-- NVIDIA_H100_NVL
| `-- TP_1
| |-- config.json
| `-- rank0.engine
|-- tensorrtllm_models
| `-- llama-3.1-8b-instruct
| `-- NVIDIA_H100_NVL
| `-- TP_1
| |-- context
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| |-- generate
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| |-- llama-3.1-8b-instruct
| | |-- 1
| | `-- config.pbtxt
| |-- postprocessing
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| |-- preprocessing
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| `-- tensorrt_llm
| |-- 1
| | `-- model.py
| `-- config.pbtxt
`-- triton_core_models
|-- mock
| |-- 1
| | `-- model.py
| `-- config.pbtxt
|-- simple_postprocessing
| |-- 1
| | `-- model.py
| `-- config.pbtxt
`-- simple_preprocessing
|-- 1
| `-- model.py
`-- config.pbtxt
```
### 3.3: Deployment Example
To start a basic deployment with 1 prefill and 1 decode worker:
```bash
export MODEL_NAME="llama-3.1-8b-instruct"
python3 /workspace/examples/llm/tensorrtllm/deploy/launch_workers.py \
--context-worker-count 1 \
--generate-worker-count 1 \
--model ${MODEL_NAME} \
--initialize-request-plane \
--disaggregated-serving \
--request-plane-uri ${HOSTNAME}:4222 &
```
Then start the OpenAI-compatible API server:
```bash
python3 -m llm.api_server \
--tokenizer meta-llama/Llama-3.1-8B-Instruct \
--request-plane-uri ${HOSTNAME}:4222 \
--api-server-host ${HOSTNAME} \
--model-name ${MODEL_NAME} &
```
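The API server also accepts the `--backend` option added in this commit (`vllm` by default). To explicitly select the TensorRT-LLM chat handler, the same command can be extended, for example:

```bash
python3 -m llm.api_server \
    --tokenizer meta-llama/Llama-3.1-8B-Instruct \
    --request-plane-uri ${HOSTNAME}:4222 \
    --api-server-host ${HOSTNAME} \
    --model-name ${MODEL_NAME} \
    --backend tensorrtllm &
```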
### 3.4: Sending Requests
Once the API server is running (by default on `localhost:8000`), you can send OpenAI-compatible requests. For example:
```bash
curl ${HOSTNAME}:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3.1-8b-instruct",
"messages": [
{"role": "user", "content": "Why is Roger Federer the greatest tennis player of all time?"}
],
"temperature": 0,
"top_p": 0.95,
"max_tokens": 25,
"stream": true,
"n": 1,
"frequency_penalty": 0.0,
"stop": []
}'
```
The above request will return a streamed response with the model’s answer.
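Because `"stream": true` is set, the response arrives as OpenAI-style server-sent events. The exact fields depend on the frontend; an illustrative (not verbatim) chunk sequence looks roughly like:

```
data: {"id": "...", "object": "chat.completion.chunk", "model": "llama-3.1-8b-instruct", "choices": [{"index": 0, "delta": {"content": "Roger"}, "finish_reason": null}]}

data: [DONE]
```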
## 4. Teardown
To tear down a deployment during local development, you can either kill the
container or kill the relevant processes involved in the deployment.
To kill the processes being run inside the container, you can run:
```bash
pkill -SIGINT -f python3
pkill -SIGINT -f nats-server
```
You will generally want to make sure you have a clean slate between
deployments to avoid any unexpected errors.
NOTE: If you have other unrelated processes in the environment with `python3`
in the name, the `pkill` command above will terminate them as well. In this
scenario, you could select specific process IDs and use the following command
instead for each process ID replacing `<pid>` below:
```
kill -9 <pid>
```
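To list candidate process IDs and their command lines before killing them, standard tooling such as `pgrep` (not specific to this repository) can help:

```bash
# Show python3 processes with their full command lines
pgrep -af python3
```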
## 5. Known Issues & Limitations
1. **Tensor Parallelism Constraints**
- Currently limited to TP=1 for both prefill and decode workers
## 6. References
[^1]: Yinmin Zhong, Shengyu Liu, Junda Chen, Jianbo Hu, Yibo Zhu, Xuanzhe Liu, Xin Jin, and Hao
Zhang. Distserve: Disaggregating prefill and decoding for goodput-optimized large language
model serving. *arXiv:2401.09670v3 [cs.DC]*, 2024.
For more details on Triton Distributed, see the [Hello World example](../../hello_world/) and [Triton Inference Server documentation](https://github.com/triton-inference-server/server).
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signal
import sys
import time
from pathlib import Path
from llm.tensorrtllm.operators.disaggregated_serving import DisaggregatedServingOperator
from llm.tensorrtllm.scripts.gpu_info import get_gpu_product_name
from triton_distributed.worker import (
OperatorConfig,
TritonCoreOperator,
Worker,
WorkerConfig,
)
from .parser import parse_args
deployment = None
def handler(signum, frame):
exit_code = 0
if deployment:
print("Stopping Workers")
exit_code = deployment.stop()
print(f"Workers Stopped Exit Code {exit_code}")
sys.exit(exit_code)
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
try:
signal.signal(sig, handler)
except Exception:
pass
def _create_disaggregated_serving_op(name, args, max_inflight_requests):
model_repository = str(
Path(args.operator_repository) / "triton_core_models"
) # stores our simple pre/post processing
return OperatorConfig(
name=name,
implementation=DisaggregatedServingOperator,
max_inflight_requests=int(max_inflight_requests),
repository=model_repository,
)
def _create_triton_core_op(
name,
max_inflight_requests,
args,
):
# TODO: argparse repo
gpu_name = get_gpu_product_name()
return OperatorConfig(
name=name,
implementation=TritonCoreOperator,
max_inflight_requests=int(max_inflight_requests),
repository=str(
Path(args.operator_repository)
/ "tensorrtllm_models"
/ args.model
/ gpu_name
/ "TP_1"
),
parameters={
"store_outputs_in_response": True,
"config": {
"parameters": {
"participant_ids": {"string_value": f"{args.gpu_device_id}"},
"gpu_device_ids": {"string_value": f"{args.gpu_device_id}"},
}
},
},
)
def main(args):
if args.log_dir:
log_dir = Path(args.log_dir)
log_dir.mkdir(exist_ok=True)
worker_configs = []
if args.worker_type == "aggregate":
aggregate_op = _create_triton_core_op(
name=args.model, max_inflight_requests=1000, args=args
)
aggregate = WorkerConfig(
operators=[aggregate_op],
name=args.model,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
metrics_port=args.metrics_port,
)
worker_configs.append(aggregate)
# Context/Generate workers used for Disaggregated Serving
elif args.worker_type == "context":
prefill_op = _create_triton_core_op(
name="context",
max_inflight_requests=1000,
args=args,
)
prefill = WorkerConfig(
operators=[prefill_op],
name="context",
log_level=args.log_level,
metrics_port=args.metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
worker_configs.append(prefill)
elif args.worker_type == "generate":
decoder_op = _create_triton_core_op(
name="generate",
max_inflight_requests=1000,
args=args,
)
decoder = WorkerConfig(
operators=[decoder_op],
name="generate",
log_level=args.log_level,
metrics_port=args.metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
worker_configs.append(decoder)
elif args.worker_type == "disaggregated-serving":
prefill_decode_op = _create_disaggregated_serving_op(
name=args.model,
max_inflight_requests=1000,
args=args,
)
prefill_decode = WorkerConfig(
operators=[prefill_decode_op],
name=args.worker_name,
log_level=args.log_level,
metrics_port=args.metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
worker_configs.append(prefill_decode)
print("Starting Worker")
for worker_config in worker_configs:
worker = Worker(worker_config)
print(f"worker: {worker}")
worker.start()
print("Worker started ... press Ctrl-C to Exit")
while True:
time.sleep(10)
if __name__ == "__main__":
args = parse_args()
main(args)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import signal
import subprocess
import sys
from pathlib import Path
from llm.tensorrtllm.deploy.parser import parse_args
deployment = None
def handler(signum, frame):
exit_code = 0
if deployment:
print("Stopping Workers")
exit_code = deployment.stop()
print(f"Workers Stopped Exit Code {exit_code}")
sys.exit(exit_code)
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
try:
signal.signal(sig, handler)
except Exception:
pass
def _launch_mpi_workers(args):
if (
args.context_worker_count == 1
or args.generate_worker_count == 1
or args.aggregate_worker_count == 1
):
command = [
"mpiexec",
"--allow-run-as-root",
"--oversubscribe",
"--display-map",
"--verbose",
]
if args.log_dir:
WORKER_LOG_DIR = str(Path(args.log_dir) / "workers")
command += ["--output-filename", WORKER_LOG_DIR]
aggregate_gpus = args.context_worker_count + args.generate_worker_count
for index in range(args.context_worker_count):
starting_gpu = index * aggregate_gpus
command.extend(_context_cmd(args, starting_gpu))
command.append(":")
for index in range(args.generate_worker_count):
starting_gpu = index * aggregate_gpus + args.context_worker_count
command.extend(_generate_cmd(args, starting_gpu))
command.append(":")
for index in range(args.aggregate_worker_count):
starting_gpu = index * aggregate_gpus + args.context_worker_count
command.extend(_aggregate_cmd(args, starting_gpu))
command.append(":")
command = command[0:-1]
print(" ".join(command))
if args.dry_run:
return
env = os.environ.copy()
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
else:
raise ValueError("Only supporting 1 worker each for now")
def _launch_disagg_model(args):
if not args.disaggregated_serving:
return
starting_gpu = 0
env = os.environ.copy()
command = _disaggregated_serving_cmd(args, starting_gpu)
print(" ".join(command))
if args.dry_run:
return
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
def _launch_workers(args):
# Launch nats-server if requested by user for convenience, otherwise
# it can be started separately beforehand.
if args.initialize_request_plane:
_launch_nats_server(args)
# Launch TRT-LLM models via mpiexec in the same MPI WORLD
_launch_mpi_workers(args)
# Launch disaggregated serving "workflow" model to interface
# client-facing requests with Triton Distributed deployment.
_launch_disagg_model(args)
def _context_cmd(args, starting_gpu):
# Hard-coded worker name for internal communication,
# see tensorrtllm.deploy script
worker_name = "context"
command = [
"-np",
"1",
# FIXME: May need to double check this CUDA_VISIBLE_DEVICES
# and trtllm gpu_device_id/participant_id interaction
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"context",
"--worker-name",
worker_name,
"--model",
args.model,
"--gpu-device-id",
f"{starting_gpu}",
"--metrics-port",
"50000",
"--initialize-request-plane",
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _generate_cmd(args, starting_gpu):
# Hard-coded worker name for internal communication
# see tensorrtllm.deploy script
worker_name = "generate"
command = [
"-np",
"1",
# FIXME: May need to double check this CUDA_VISIBLE_DEVICES
# and trtllm gpu_device_id/participant_id interaction
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"generate",
"--worker-name",
worker_name,
"--model",
args.model,
"--gpu-device-id",
f"{starting_gpu}",
"--metrics-port",
"50001",
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _aggregate_cmd(args, starting_gpu):
# Hard-coded worker name for internal communication
# see tensorrtllm.deploy script
worker_name = "aggregate"
command = [
"-np",
"1",
# FIXME: May need to double check this CUDA_VISIBLE_DEVICES
# and trtllm gpu_device_id/participant_id interaction
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"aggregate",
"--worker-name",
worker_name,
"--model",
args.model,
"--gpu-device-id",
f"{starting_gpu}",
"--metrics-port",
"50001",
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _disaggregated_serving_cmd(args, starting_gpu):
# NOTE: This worker gets the args --worker-name because it will
# receive the API-serving facing requests, and internally handle
# the disaggregation. So this worker name should match the one
# registered to the API Server.
command = [
# FIXME: Does this model need a GPU assigned to it?
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"disaggregated-serving",
"--metrics-port",
"50002",
"--model",
args.model,
"--worker-name",
args.worker_name,
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _launch_nats_server(args, clear_store=True):
# FIXME: Use NatsServer object defined in icp package
store_dir = "/tmp/nats_store"
if clear_store:
shutil.rmtree(store_dir, ignore_errors=True)
command = [
"/usr/local/bin/nats-server",
"--jetstream",
"--port",
str(args.nats_port),
"--store_dir",
store_dir,
]
print(" ".join(command))
if args.dry_run:
return
env = os.environ.copy()
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
if __name__ == "__main__":
args = parse_args()
_launch_workers(args)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
def parse_args():
parser = argparse.ArgumentParser(
description="Run an example of the TensorRT-LLM pipeline."
)
example_dir = Path(__file__).parent.absolute().parent.absolute()
default_operator_repository = example_dir.joinpath("operators")
default_log_dir = ""
parser.add_argument(
"--log-dir",
type=str,
default=str(default_log_dir),
help="log dir folder",
)
parser.add_argument(
"--initialize-request-plane",
default=False,
action="store_true",
help="Initialize the request plane, should only be done once per deployment",
)
parser.add_argument(
"--log-level", type=int, default=3, help="log level applied to all workers"
)
parser.add_argument(
"--request-plane-uri",
type=str,
default="nats://localhost:4222",
help="URI of request plane",
)
parser.add_argument(
"--nats-port",
type=int,
default=4222,
help="Port for NATS server",
)
parser.add_argument(
"--metrics-port",
type=int,
default=50000,
help="Metrics port",
)
parser.add_argument(
"--worker-type",
type=str,
default="aggregate",
help="Type of worker",
choices=["aggregate", "context", "generate", "disaggregated-serving"],
)
parser.add_argument("--gpu-device-id", type=int, default=0, help="gpu id")
parser.add_argument(
"--context-worker-count", type=int, default=0, help="Number of context workers"
)
parser.add_argument(
"--generate-worker-count",
type=int,
default=0,
help="Number of generate workers",
)
parser.add_argument(
"--aggregate-worker-count",
type=int,
required=False,
default=0,
help="Number of baseline workers",
)
parser.add_argument(
"--operator-repository",
type=str,
default=str(default_operator_repository),
help="Operator repository",
)
parser.add_argument(
"--worker-name",
type=str,
required=False,
default="llama",
help="Name of the worker",
)
parser.add_argument(
"--model",
type=str,
required=False,
default="llama-3.1-8b-instruct",
choices=[
"mock",
"llama-3.1-70b-instruct",
"llama-3.1-8b-instruct",
"llama-3-8b-instruct-generate",
"llama-3-8b-instruct-context",
"llama-3-8b-instruct",
"llama-3-8b-instruct-default",
"llama-3-70b-instruct-context",
"llama-3-70b-instruct-generate",
"llama-3-70b-instruct",
],
help="model to serve",
)
parser.add_argument(
"--ignore-eos",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Ignore EOS token when generating",
)
parser.add_argument(
"--dry-run",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Dry run the command",
)
parser.add_argument(
"--disaggregated-serving",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Enable disaggregated serving",
)
return parser.parse_args()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llm.tensorrtllm.operators.disaggregated_serving import DisaggregatedServingOperator
__all__ = ["DisaggregatedServingOperator"]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import numpy
from triton_distributed.worker import (
RemoteInferenceRequest,
RemoteOperator,
TritonCoreOperator,
)
class DisaggregatedServingOperator(TritonCoreOperator):
def __init__(
self,
name,
version,
triton_core,
request_plane,
data_plane,
parameters,
repository,
logger,
):
self._prefill = RemoteOperator("context", request_plane, data_plane)
self._decode = RemoteOperator("generate", request_plane, data_plane)
self._repository = repository
self._triton_core = triton_core
self._triton_core.register_model_repository(repository)
self._preprocess_model = self._triton_core.load("simple_preprocessing")
self._postprocess_model = self._triton_core.load("simple_postprocessing")
self._logger = logger
self._store_outputs_in_response = True
async def execute(self, requests: list[RemoteInferenceRequest]):
self._logger.debug("Executing DisaggregatedServing Request")
background_tasks = []
for request in requests:
task = asyncio.create_task(self._execute_request(request))
background_tasks.append(task)
try:
results = await asyncio.gather(*background_tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
self._logger.exception(
f"Running request execution failed: {result}"
)
else:
self._logger.debug(
f"Request execution completed with result: {result}"
)
except Exception as e:
self._logger.exception(f"Error during request execution: {e}")
async def _execute_request(self, request: RemoteInferenceRequest):
background_tasks = []
prefill_inputs = {}
sampling_params = {}
response_sender = request.response_sender()
"""Preprocessing"""
print(request)
if "text_input" in request.inputs:
query = request.inputs["text_input"].to_bytes_array()
elif "prompt" in request.inputs:
query = request.inputs["prompt"].to_bytes_array()
elif "prompt" in request.parameters:
query = request.parameters["prompt"]
else:
await response_sender.send(error=f"invalid request {request}", final=True)
return
if "sampling_params" in request.parameters:
sampling_params = json.loads(
request.parameters["sampling_params"].removeprefix("JSON:")
)
if "max_tokens" in request.inputs:
request_output_len = request.inputs["max_tokens"]
elif "max_tokens" in sampling_params:
request_output_len = numpy.array(
[[sampling_params["max_tokens"]]], dtype=numpy.int32
)
streaming = request.parameters.get("streaming", False)
input_ids, input_lengths = await self._preprocess(query)
print(input_ids, input_lengths)
prefill_inputs["input_ids"] = input_ids
prefill_inputs["input_lengths"] = input_lengths
prefill_inputs["request_output_len"] = request_output_len
"""Prefill"""
prefill_parameters = {}
prefill_parameters["request_type"] = "context_only"
self._logger.debug(
f"Executing request on context worker with inputs: {prefill_inputs}"
)
async for prefill_response in await self._prefill.async_infer(
inputs=prefill_inputs,
parameters=prefill_parameters,
):
self._logger.debug(f"Prefill response completed: {prefill_response}")
output_ids = numpy.from_dlpack(prefill_response.outputs["output_ids"])
self._logger.debug(f"Output IDs: {output_ids}")
if streaming:
tasks = asyncio.create_task(
self._send_llm_response(
prefill_response, response_sender, final=False
)
)
background_tasks.append(tasks)
"""Decode"""
decode_parameters = {}
decode_parameters["request_type"] = "generation_only"
decode_inputs = {}
decode_inputs["context_phase_params"] = prefill_response.outputs[
"context_phase_params"
]
decode_inputs["input_ids"] = input_ids
decode_inputs["input_lengths"] = input_lengths
decode_inputs["request_output_len"] = request_output_len
async for decode_response in await self._decode.async_infer(
inputs=decode_inputs,
parameters=decode_parameters,
):
self._logger.debug(f"Decode response completed: {decode_response}")
background_tasks.append(
asyncio.create_task(
self._send_llm_response(
decode_response,
response_sender,
final=decode_response.final,
)
)
)
try:
results = await asyncio.gather(*background_tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
self._logger.exception(
f"Sending response failed with exception: {result}"
)
else:
self._logger.debug(f"Response sent successfully: {result}")
except Exception as e:
self._logger.exception(f"Error during response sending: {e}")
for output in prefill_response.outputs:
del output
for output in decode_response.outputs:
del output
async def _preprocess(self, query):
start_ids = None
start_lengths = None
print("here!")
if isinstance(query, str):
query = [[query]]
async for preprocess_response in self._preprocess_model.async_infer(
inputs={"query": query}
):
self._logger.debug(f"Preprocess response completed: {preprocess_response}")
start_ids = numpy.from_dlpack(preprocess_response.outputs["start_ids"])
start_lengths = numpy.from_dlpack(
preprocess_response.outputs["start_lengths"]
)
return start_ids, start_lengths
async def _postprocessing(self, tokens_batch, sequence_lengths):
outputs = []
async for postprocess_response in self._postprocess_model.async_infer(
inputs={"tokens_batch": tokens_batch, "sequence_lengths": sequence_lengths}
):
self._logger.debug(f"Received postprocess response: {postprocess_response}")
output = postprocess_response.outputs["output"].to_string_array()
outputs.append(output)
return outputs
async def _send_llm_response(self, llm_response, response_sender, final):
tokens_batch = numpy.from_dlpack(llm_response.outputs["output_ids"])
self._logger.debug(f"Output ids length: {tokens_batch}")
sequence_length = numpy.from_dlpack(llm_response.outputs["sequence_length"])
output = await self._postprocessing(tokens_batch, sequence_length)
store_outputs_in_response = set()
if self._store_outputs_in_response:
store_outputs_in_response.add("text_output")
await response_sender.send(
outputs={"text_output": output[0]},
final=final,
store_outputs_in_response=store_outputs_in_response,
)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import threading
import time
import numpy as np
import triton_python_backend_utils as pb_utils
DEFAULT_OUTPUT_LEN = 1000
class TritonPythonModel:
def initialize(self, args):
self._logger = pb_utils.Logger
model_config = json.loads(args["model_config"])
self._generate_token_latency = (
float(
model_config["parameters"]["generate_token_latency_ms"]["string_value"]
)
) / 1000
self._context_token_latency = (
float(
model_config["parameters"]["context_token_latency_ms"]["string_value"]
)
) / 1000
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
model_config
)
if not using_decoupled:
raise pb_utils.TritonModelException(
"""the model `{}` can generate any number of responses per request,
enable decoupled transaction policy in model configuration to
serve this model""".format(
args["model_name"]
)
)
for output_name in ["output_ids", "sequence_length", "context_phase_params"]:
setattr(
self,
output_name.lower() + "_dtype",
pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(model_config, output_name)[
"data_type"
]
),
)
# To keep track of response threads so that we can delay
# the finalizing the model until all response threads
# have completed.
self.inflight_thread_count = 0
self.inflight_thread_count_lck = threading.Lock()
def response_thread(self, response_sender, inputs):
streaming = inputs["streaming"][0]
request_type = inputs["request_type"]
output_ids = []
output_sequence_length = inputs["request_output_len"][0]
self._logger.log_verbose(
f"Starting Response Thread: {threading.get_native_id()}"
)
self._logger.log_verbose(f"Inputs: {inputs}")
self._logger.log_verbose(f"Streaming: {streaming}")
self._logger.log_verbose(f"Request Type: {request_type}")
input_sequence_length = inputs["input_lengths"][0]
if inputs["request_type"] != "generate_only":
for _ in inputs["input_ids"][0]:
time.sleep(self._context_token_latency)
if request_type != "context_only":
for index in range(output_sequence_length):
output_ids.append(inputs["input_ids"][0][index % input_sequence_length])
if streaming:
output_ids_tensor = pb_utils.Tensor(
"output_ids",
np.array([[[output_ids[-1]]]]).astype(self.output_ids_dtype),
)
sequence_length_tensor = pb_utils.Tensor(
"sequence_length",
np.array([[1]]).astype(self.sequence_length_dtype),
)
response = pb_utils.InferenceResponse(
output_tensors=[output_ids_tensor, sequence_length_tensor]
)
flags = (
pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
if index == output_sequence_length - 1
else 0
)
response_sender.send(response, flags=flags)
time.sleep(self._generate_token_latency)
if not streaming:
output_ids_tensor = pb_utils.Tensor(
"output_ids", np.array([[output_ids]]).astype(self.output_ids_dtype)
)
sequence_length_tensor = pb_utils.Tensor(
"sequence_length",
np.array([[output_sequence_length]]).astype(
self.sequence_length_dtype
),
)
response = pb_utils.InferenceResponse(
output_tensors=[output_ids_tensor, sequence_length_tensor]
)
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
if request_type == "context_only":
output_ids.append(inputs["input_ids"][0][0])
output_ids_tensor = pb_utils.Tensor(
"output_ids",
np.array([[output_ids]]).astype(self.output_ids_dtype),
)
sequence_length_tensor = pb_utils.Tensor("sequence_length", np.array([[1]]))
context_phase_params = pb_utils.Tensor(
"context_phase_params",
np.array([[1, 2, 3, 4]]).astype(self.context_phase_params_dtype),
)
response = pb_utils.InferenceResponse(
output_tensors=[
output_ids_tensor,
sequence_length_tensor,
context_phase_params,
]
)
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
# We must close the response sender to indicate to Triton that we are
# done sending responses for the corresponding request. We can't use the
# response sender after closing it. The response sender is closed by
# setting the TRITONSERVER_RESPONSE_COMPLETE_FINAL.
# response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
with self.inflight_thread_count_lck:
self.inflight_thread_count -= 1
self._logger.log_verbose(
f"Exiting Response Thread: {threading.get_native_id()}"
)
def _get_inputs(self, request):
inputs = [
"context_phase_params",
"streaming",
"min_length",
"request_output_len",
"input_lengths",
"input_ids",
]
result = {}
for input_ in inputs:
value = pb_utils.get_input_tensor_by_name(request, input_)
if value is not None:
result[input_] = value.as_numpy()
input_parameters = json.loads(request.parameters())
if "request_type" in input_parameters:
result["request_type"] = input_parameters["request_type"]
else:
result["request_type"] = "aggregate"
if "request_output_len" not in inputs:
result["request_output_len"] = DEFAULT_OUTPUT_LEN
if "streaming" not in result:
result["streaming"] = [False]
return result
def execute(self, requests):
for idx, request in enumerate(requests):
inputs = self._get_inputs(request)
# Start a separate thread to send the responses for the request. The
# sending back the responses is delegated to this thread.
thread = threading.Thread(
target=self.response_thread,
args=(request.get_response_sender(), inputs),
)
# A model using decoupled transaction policy is not required to send all
# responses for the current request before returning from the execute.
# To demonstrate the flexibility of the decoupled API, we are running
# response thread entirely independent of the execute thread.
thread.daemon = True
with self.inflight_thread_count_lck:
self.inflight_thread_count += 1
thread.start()
return None
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Emulates the tensorrt_llm config from:
# https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
backend: "python"
max_batch_size: ${triton_max_batch_size}
model_transaction_policy {
decoupled: true
}
dynamic_batching {
preferred_batch_size: [ ${triton_max_batch_size} ]
max_queue_delay_microseconds: 0
default_queue_policy: { max_queue_size: 0 }
}
parameters: {
key: "context_token_latency_ms"
value: {
string_value: "${context_token_latency_ms}"
}
}
parameters: {
key: "generate_token_latency_ms"
value: {
string_value: "${generate_token_latency_ms}"
}
}
input [
{
name: "input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
allow_ragged_batch: true
},
{
name: "encoder_input_features"
data_type: TYPE_FP16
dims: [ -1, -1 ]
allow_ragged_batch: true
optional: true
},
{
name: "encoder_output_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "request_output_len"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "num_return_sequences"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "draft_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
reshape: { shape: [ ] }
},
{
name: "draft_logits"
data_type: TYPE_FP32
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "draft_acceptance_threshold"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "end_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "pad_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "bad_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "embedding_bias"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "beam_width"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "temperature"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_k"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_min"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_decay"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_reset_ids"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "len_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "early_stopping"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "repetition_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "min_length"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "beam_search_diversity_rate"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "presence_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "frequency_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "random_seed"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_log_probs"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_context_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_generation_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_kv_cache_reuse_stats"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "exclude_input_in_output"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "streaming"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "prompt_embedding_table"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_table_extra_ids"
data_type: TYPE_UINT64
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_vocab_size"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
{
name: "cross_attention_mask"
data_type: TYPE_BOOL
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# the unique task ID for the given LoRA.
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
{
name: "lora_task_id"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
# each of the in / out tensors are first flattened and then concatenated together in the format above.
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
{
name: "lora_weights"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# module identifier (same size a first dimension of lora_weights)
# See LoraModule::ModuleType for model id mapping
#
# "attn_qkv": 0 # compbined qkv adapter
# "attn_q": 1 # q adapter
# "attn_k": 2 # k adapter
# "attn_v": 3 # v adapter
# "attn_dense": 4 # adapter for the dense layer in attention
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
#
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
{
name: "lora_config"
data_type: TYPE_INT32
dims: [ -1, 3 ]
optional: true
allow_ragged_batch: true
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
{
name: "skip_cross_attn_blocks"
data_type: TYPE_BOOL
dims: [ 1 ]
optional: true
allow_ragged_batch: true
}
]
output [
{
name: "output_ids"
data_type: TYPE_INT32
dims: [ -1, -1 ]
},
{
name: "sequence_length"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "cum_log_probs"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "output_log_probs"
data_type: TYPE_FP32
dims: [ -1, -1 ]
},
{
name: "context_logits"
data_type: TYPE_FP32
dims: [ -1, -1 ]
},
{
name: "generation_logits"
data_type: TYPE_FP32
dims: [ -1, -1, -1 ]
},
{
name: "batch_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "sequence_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
},
{
name: "kv_cache_alloc_new_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_reused_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_alloc_total_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
}
]
# Add more parameters as per requirement
instance_group [
{
count: 1
kind : KIND_CPU
}
]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
class TritonPythonModel:
"""
This model allows Triton to act like a api server for T3 ICP
"""
@staticmethod
def auto_complete_config(auto_complete_model_config):
inputs = [
{"name": "tokens_batch", "data_type": "TYPE_INT32", "dims": [-1, -1]},
{"name": "sequence_lengths", "data_type": "TYPE_INT32", "dims": [-1]},
]
outputs = [
{"name": "output", "data_type": "TYPE_STRING", "dims": [-1]},
]
# Store the model configuration as a dictionary.
config = auto_complete_model_config.as_dict()
input_names = []
output_names = []
for input in config["input"]:
input_names.append(input["name"])
for output in config["output"]:
output_names.append(output["name"])
# Add only missing inputs and output to the model configuration.
for input in inputs:
if input["name"] not in input_names:
auto_complete_model_config.add_input(input)
for output in outputs:
if output["name"] not in output_names:
auto_complete_model_config.add_output(output)
return auto_complete_model_config
def initialize(self, args):
model_config = json.loads(args["model_config"])
self.logger = pb_utils.Logger
# Parse model configs
model_config = json.loads(args["model_config"])
tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"]
skip_special_tokens = model_config["parameters"].get("skip_special_tokens")
if skip_special_tokens is not None:
skip_special_tokens_str = skip_special_tokens["string_value"].lower()
if skip_special_tokens_str in [
"true",
"false",
"1",
"0",
"t",
"f",
"y",
"n",
"yes",
"no",
]:
self.skip_special_tokens = skip_special_tokens_str in [
"true",
"1",
"t",
"y",
"yes",
]
else:
self.logger.log_warn(
f"[TensorRT-LLM][WARNING] Don't setup 'skip_special_tokens' correctly (set value is {skip_special_tokens['string_value']}). Set it as True by default."
)
self.skip_special_tokens = True
else:
self.logger.log_warn(
"[TensorRT-LLM][WARNING] Don't setup 'skip_special_tokens'. Set it as True by default."
)
self.skip_special_tokens = True
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, legacy=False, padding_side="left", trust_remote_code=True
)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
for output_name in ["output"]:
setattr(
self,
output_name.lower() + "_dtype",
pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(model_config, output_name)[
"data_type"
]
),
)
def execute(self, requests):
tokens_batch = []
sequence_lengths = []
for idx, request in enumerate(requests):
for input_tensor in request.inputs():
if input_tensor.name() == "tokens_batch":
tokens_batch.append(input_tensor.as_numpy())
elif input_tensor.name() == "sequence_lengths":
sequence_lengths.append(input_tensor.as_numpy())
else:
raise ValueError(f"unknown input {input_tensor.name}")
# batch decode
list_of_tokens = []
req_idx_offset = 0
req_idx_offsets = [req_idx_offset]
for idx, token_batch in enumerate(tokens_batch):
for batch_idx, beam_tokens in enumerate(token_batch):
for beam_idx, tokens in enumerate(beam_tokens):
seq_len = sequence_lengths[idx][batch_idx][beam_idx]
list_of_tokens.append(tokens[:seq_len])
req_idx_offset += 1
req_idx_offsets.append(req_idx_offset)
all_outputs = self.tokenizer.batch_decode(
list_of_tokens, skip_special_tokens=self.skip_special_tokens
)
# construct responses
responses = []
for idx, request in enumerate(requests):
req_outputs = [
x.encode("utf8")
for x in all_outputs[req_idx_offsets[idx] : req_idx_offsets[idx + 1]]
]
output_tensor = pb_utils.Tensor(
"output", np.array(req_outputs).astype(self.output_dtype)
)
outputs = [output_tensor]
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
responses.append(inference_response)
return responses
def finalize(self):
"""`finalize` is called only once when the model is being unloaded.
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
print("Cleaning up...")
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: "python"
# TODO: Tune dynamic batcher
max_batch_size: 64
dynamic_batching {}
parameters {
key: "tokenizer_dir"
value: {
string_value: "/workspace/examples/llm/tensorrtllm/operators/hf_downloads/llama-3.1-8b-instruct"
}
}
#parameters {
# key: "skip_special_tokens"
# value: {
# string_value: "${skip_special_tokens}"
# }
#}
instance_group [
{
count: 10
kind : KIND_CPU
}
]
\ No newline at end of file