Commit 0bfd9a76 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub

refactor: remove python native runtime

parent 8f741f14
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from frontend.fastapi_frontend import FastApiFrontend
from llm.api_server.triton_distributed_engine import TritonDistributedEngine
from triton_distributed.runtime.logger import get_logger
from .parser import parse_args
def main(args):
print(args)
logger = get_logger(args.log_level, args.program_name)
logger.info("Starting")
# Wrap Triton Distributed in an interface-conforming "LLMEngine"
engine: TritonDistributedEngine = TritonDistributedEngine(
nats_url=args.request_plane_uri,
data_plane_host=args.data_plane_host,
data_plane_port=args.data_plane_port,
model_name=args.model_name,
tokenizer=args.tokenizer,
backend=args.backend,
)
# Attach TritonLLMEngine as the backbone for inference and model management
openai_frontend: FastApiFrontend = FastApiFrontend(
engine=engine,
host=args.api_server_host,
port=args.api_server_port,
log_level=args.log_level,
)
# Blocking call until killed or interrupted with SIGINT
openai_frontend.start()
if __name__ == "__main__":
parser, args = parse_args()
args.program_name = parser.prog
main(args)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from llm.api_server.connector import (
BaseTriton3Connector,
InferenceRequest,
InferenceResponse,
TritonInferenceError,
)
# FIXME: Integrate better with api_server library
from schemas.openai import (
ChatCompletionChoice,
ChatCompletionFinishReason,
ChatCompletionResponseMessage,
ChatCompletionStreamingResponseChoice,
ChatCompletionStreamResponseDelta,
CreateChatCompletionRequest,
CreateChatCompletionResponse,
CreateChatCompletionStreamResponse,
ObjectType,
)
from transformers import AutoTokenizer
LOGGER = logging.getLogger(__name__)
"""
Example request with curl
curl -X 'POST' \\
'http://{host}:{port}/v1/chat/completions' \\
-H 'accept: application/json' \\
-H 'Content-Type: application/json' \\
-d '{{
"model": "{model}",
"messages": [
{{
"role":"user",
"content":"Hello! How are you?"
}},
{{
"role":"assistant",
"content":"Hi! I am quite well, how can I help you today?"
}},
{{
"role":"user",
"content":"Can you write me a song?"
}}
],
"top_p": 1,
"n": 1,
"max_tokens": 15,
"stream": true,
"frequency_penalty": 1.0,
"stop": ["hello"]
}}'
"""
def generate_sampling_params(
request: CreateChatCompletionRequest,
non_supported_params_none: Optional[List[str]] = None,
) -> Dict[str, Any]:
errors_message = ""
if not non_supported_params_none:
non_supported_params_none = []
for param in non_supported_params_none:
if getattr(request, param, None) is not None:
errors_message += f"The parameter '{param}' is not supported. "
if errors_message:
raise ValueError(errors_message)
sampling_params = {}
if request.temperature is not None:
sampling_params["temperature"] = request.temperature
if request.n is not None:
sampling_params["n"] = request.n
if request.top_p is not None:
sampling_params["top_p"] = request.top_p
if request.presence_penalty is not None:
sampling_params["presence_penalty"] = request.presence_penalty
if request.frequency_penalty is not None:
sampling_params["frequency_penalty"] = request.frequency_penalty
if request.max_tokens is not None:
sampling_params["max_tokens"] = request.max_tokens
if request.min_tokens is not None:
sampling_params["min_tokens"] = request.min_tokens
if request.stop is not None:
sampling_params["stop"] = request.stop
if request.seed is not None:
sampling_params["seed"] = request.seed
return sampling_params
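# Illustrative only (not part of this commit): for a request with
# temperature=0.7 and max_tokens=128 and all other sampling fields unset,
# the function above returns {"temperature": 0.7, "max_tokens": 128}.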
def create_chat_response(
request_id: str,
model: str,
model_output: Union[np.ndarray, List[str]],
role: str,
prompt: str,
) -> CreateChatCompletionResponse:
"""Create chunk responses from the detokenized outputs for non-streaming completions"""
detokenized_outputs = model_output
# Extract prompt from detokenized_outputs
cleaned_outputs = []
for detokenized_output in detokenized_outputs:
# FIXME - should be receiving array of strings by this point, not nested arrays
if isinstance(detokenized_output, np.ndarray):
detokenized_output = str(detokenized_output[0])
if not isinstance(detokenized_output, str):
raise RuntimeError(
f"ERROR: detokenized_output is not a string! {type(detokenized_output)=} | {detokenized_output=}"
)
# FIXME: Should this be handled by 'echo' param instead?
if detokenized_output.startswith(prompt):
cleaned_output = detokenized_output[len(prompt) :]
else:
cleaned_output = detokenized_output
cleaned_outputs.append(cleaned_output)
messages = [
ChatCompletionResponseMessage(role=role, content=output_str)
for output_str in cleaned_outputs
]
choices = [
ChatCompletionChoice(
index=idx,
message=message,
finish_reason=ChatCompletionFinishReason.stop,
logprobs=None,
)
for idx, message in enumerate(messages)
]
chat_response = CreateChatCompletionResponse(
id=request_id,
choices=choices,
created=int(time.time()),
model=model,
system_fingerprint=None,
object=ObjectType.chat_completion,
)
return chat_response
def generate_delta(
output_str: str, role: str, previous_output: Optional[str] = None
) -> ChatCompletionStreamResponseDelta:
"""Generate the delta from the output string
Args:
output_str (str): The output string from the model.
role (str): The role of the AI generating the output.
previous_output (Optional[str]): The previous output string. Defaults to None.
Example:
print(generate_delta("user: Hello!, assistant: Hi!", "assistant", "user: Hello!, assistant: "))
# Output: Delta(role='assistant', content='Hi!')
"""
if previous_output is None:
return ChatCompletionStreamResponseDelta(role=role, content=output_str)
else:
# FIXME: Should we be manually finding the delta here? Or full text from last response?
delta_start = output_str.find(previous_output)
if delta_start == -1:
LOGGER.warning(
f"Previous output \n<START>\n{previous_output}\n<END>\n not found in the output string: \n<START>\n{output_str}\n<END>\n"
)
return ChatCompletionStreamResponseDelta(role=role, content=output_str)
delta = output_str[delta_start + len(previous_output) :]
return ChatCompletionStreamResponseDelta(role=role, content=delta)
def create_chunk_responses(
request_id: str,
model: str,
model_output: List[str],
role: str,
previous_output: Optional[list[str]] = None,
logprobs: Optional[np.ndarray] = None,
finish_reason: Optional[str] = None,
) -> Tuple[CreateChatCompletionStreamResponse, List[str]]:
"""Create chunk responses from the detokenized outputs for streaming completions.
Function extracts the delta from the output string and creates a chunk response. It also updates the previous output.
Args:
request_id (str): The unique identifier for the request.
model (str): The model used for generating the response.
model_output (List[str]): The list of output strings from the model.
role (str): The role of the AI generating the output.
previous_output (Optional[List[str]]): The previous output strings, one per choice. Defaults to None.
logprobs (Optional[np.ndarray]): The log probabilities of the output tokens. Defaults to None.
finish_reason (Optional[str]): The reason for stopping the completion. Defaults to None.
Returns:
Tuple[CreateChatCompletionStreamResponse, List[str]]: A tuple containing the chunk response and the new previous output.
Example:
create_chunk_responses(
request_id="chatcmpl-123",
model="gpt-4o-mini",
model_output=["I am fine, thank you!", "How can I help you?"],
role="assistant",
previous_output="user: Hello!, assistant: ",
logprobs=np.array([[0.1, 0.2, 0.3, 0.4, 0.5]]),
finish_reason="stop"
))
"""
detokenized_outputs = model_output
new_previous_output = []
deltas = []
for idx, output_str in enumerate(detokenized_outputs):
if previous_output is not None:
previous_output_row = previous_output[idx]
else:
previous_output_row = None
delta = generate_delta(
output_str=output_str, role=role, previous_output=previous_output_row
)
deltas.append(delta)
new_previous_output.append(output_str)
choices = []
for idx, delta in enumerate(deltas):
choice_kwargs = {
"index": idx,
"delta": delta,
# FIXME: Validate finish_reason behavior on first vs last responses
"finish_reason": finish_reason,
"logprobs": None,
}
if logprobs is not None:
choice_kwargs["logprobs"] = logprobs[idx]
chunk_choice = ChatCompletionStreamingResponseChoice(**choice_kwargs)
choices.append(chunk_choice)
chunk_response = CreateChatCompletionStreamResponse(
id=request_id,
object=ObjectType.chat_completion_chunk,
created=int(time.time()),
model=model,
system_fingerprint=request_id,
choices=choices,
)
return chunk_response, new_previous_output
class ChatHandler:
def __init__(self, triton_connector: BaseTriton3Connector, tokenizer: str):
self._triton_connector = triton_connector
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
def translate_chat_inputs(
self, request: CreateChatCompletionRequest, request_id: str, prompt: str
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
raise NotImplementedError("This method is not implemented yet")
def translate_chat_outputs(
self, response: InferenceResponse, model_name: str
) -> Dict[str, Any]:
raise NotImplementedError("This method is not implemented yet")
def stream_response_adaptor(self, response_stream):
async def adaptor_stream():
async for response in response_stream():
if isinstance(response, Exception):
yield self.exception_adaptor(response).body
else:
yield response.model_dump_json() + "\n"
return StreamingResponse(adaptor_stream(), media_type="application/json")
def response_adaptor(self, response):
return response.model_dump_json()
def exception_adaptor(self, exception):
return JSONResponse(
content={"error": str(exception), "code": 500}, status_code=500
)
async def process_request(self, request: Any, raw_request: Optional[Request]):
request_id = str(uuid.uuid4())
LOGGER.debug(f"{request=}")
prompt, role = self._create_prompt(request)
inputs, parameters = self.translate_chat_inputs(request, request_id, prompt)
triton_request = InferenceRequest(inputs=inputs, parameters=parameters)
# Streaming
if request.stream:
response_stream = self._stream_response_factory(
request_id, request.model, triton_request, prompt, role
)
return self.stream_response_adaptor(response_stream)
# Non-Streaming
response_data = None
try:
chat_outputs = None
async for response in self._triton_connector.inference(
request.model, triton_request
):
chat_outputs = self.translate_chat_outputs(response, request.model)
kwargs = {
"request_id": request_id,
"model": request.model,
"role": role,
"prompt": prompt,
}
if chat_outputs is not None:
kwargs.update(chat_outputs)
response_data = create_chat_response(**kwargs)
except TritonInferenceError as e:
logging.error(f"Error processing chat completion request: {e}")
return self.exception_adaptor(e)
LOGGER.info(f"Chat completion response: {response_data}")
return self.response_adaptor(response_data)
def _stream_response_factory(self, request_id, model, triton_request, prompt, role):
async def stream_response():
try:
previous_output = None
async for response in self._triton_connector.inference(
model, triton_request
):
# FIXME: Detect stop in response
try:
chat_outputs = self.translate_chat_outputs(response, model)
except KeyError as e:
LOGGER.info(f"KeyError {e} in response: {response}")
break
model_output = chat_outputs["model_output"]
chunk_response, new_previous_output = create_chunk_responses(
request_id=request_id,
model=model,
model_output=model_output,
role=role,
previous_output=previous_output,
)
previous_output = new_previous_output
yield f"data: {chunk_response.model_dump_json(exclude_unset=True)}\n\n"
except TritonInferenceError as e:
logging.error(f"Error processing chat completion request: {e}")
# FIXME: Does this need to conform to SSE standard for errors?
yield JSONResponse(
content={"error": str(e), "code": 500}, status_code=500
).body
finally:
yield "data: [DONE]\n\n"
return stream_response
# FIXME: Use shared/common module for these functions between
# TritonLLMEngine and TritonDistributedEngine implementations.
def _get_first_response_role(
self, conversation: List[Dict], add_generation_prompt: bool, default_role: str
) -> str:
if add_generation_prompt:
return default_role
return conversation[-1]["role"]
def _create_prompt(self, request: CreateChatCompletionRequest) -> Tuple[str, str]:
"""Create a prompt for vLLM model from the messages"""
conversation = [
message.model_dump(exclude_none=True) for message in request.messages
]
add_generation_prompt = True
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
LOGGER.debug(f"{prompt=}")
default_role = "assistant"
role = self._get_first_response_role(
conversation, add_generation_prompt, default_role
)
return prompt.strip(), role
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from llm.api_server.chat import ChatHandler, generate_sampling_params
from llm.api_server.connector import BaseTriton3Connector, InferenceResponse
from schemas.openai import CreateChatCompletionRequest
LOGGER = logging.getLogger(__name__)
# FIXME: Share request conversion logic where applicable
def generate_sampling_params_vllm(
request: CreateChatCompletionRequest,
non_supported_params: Optional[List[str]] = None,
) -> dict:
"""
Generate sampling params for vLLM from the request.
Args:
request: CreateChatCompletionRequest object.
Returns:
dict: Sampling params for vLLM.
"""
errors_message = ""
if request.logprobs:
errors_message += "The parameter 'logprobs' set to True is not supported. "
if request.tools and request.tools.type != "text":
errors_message += (
f"The parameter 'tools' type {request.tools.type} is not supported. "
)
if errors_message:
raise ValueError(errors_message)
if non_supported_params is None:
non_supported_params = [
"logit_bias",
"top_logprobs",
"tool_choice",
"user",
"service_tier",
]
sampling_params = generate_sampling_params(request, non_supported_params)
# NOTE: vLLM parameters (ex: top_k) not supported until added to schema
return sampling_params
class ChatHandlerTensorrtLLM(ChatHandler):
def __init__(
self, triton_connector: BaseTriton3Connector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, tokenizer)
self._model_name = model_name
def translate_chat_inputs(
self, request: CreateChatCompletionRequest, request_id: str, prompt: str
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""Translate the chat completion request to inference request"""
if self._model_name is not None and self._model_name != request.model:
raise ValueError(
f"Model name mismatch: {self._model_name} != {request.model}"
)
inputs: Dict[str, np.ndarray | Any] = {}
sampling_params = generate_sampling_params_vllm(request)
parameters = {
"sampling_params": sampling_params,
"request_id": request_id,
# "prompt": prompt,
}
inputs["text_input"] = [[prompt]]
inputs["max_tokens"] = numpy.array(
[[sampling_params["max_tokens"]]], dtype=numpy.int32
)
return inputs, parameters
def translate_chat_outputs(
self, response: InferenceResponse, model_name: str
) -> Dict[str, Any]:
"""Translate the inference outputs to chat completion response"""
if "text" in response.parameters:
return {"model_output": [response.parameters["text"]]}
elif "text_output" in response.outputs:
print(response.outputs["text_output"])
return {"model_output": response.outputs["text_output"][0]}
return {}
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from llm.api_server.chat import ChatHandler, generate_sampling_params
from llm.api_server.connector import BaseTriton3Connector, InferenceResponse
from schemas.openai import CreateChatCompletionRequest
LOGGER = logging.getLogger(__name__)
# FIXME: Share request conversion logic where applicable
def generate_sampling_params_vllm(
request: CreateChatCompletionRequest,
non_supported_params: Optional[List[str]] = None,
) -> dict:
"""
Generate sampling params for vLLM from the request.
Args:
request: CreateChatCompletionRequest object.
Returns:
dict: Sampling params for vLLM.
"""
errors_message = ""
if request.logprobs:
errors_message += "The parameter 'logprobs' set to True is not supported. "
if request.tools and request.tools.type != "text":
errors_message += (
f"The parameter 'tools' type {request.tools.type} is not supported. "
)
if errors_message:
raise ValueError(errors_message)
if non_supported_params is None:
non_supported_params = [
"logit_bias",
"top_logprobs",
"tool_choice",
"user",
"service_tier",
]
sampling_params = generate_sampling_params(request, non_supported_params)
# NOTE: vLLM parameters (ex: top_k) not supported until added to schema
return sampling_params
class ChatHandlerVllm(ChatHandler):
def __init__(
self, triton_connector: BaseTriton3Connector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, tokenizer)
self._model_name = model_name
def translate_chat_inputs(
self, request: CreateChatCompletionRequest, request_id: str, prompt: str
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""Translate the chat completion request to inference request"""
if self._model_name is not None and self._model_name != request.model:
raise ValueError(
f"Model name mismatch: {self._model_name} != {request.model}"
)
inputs: Dict[str, np.ndarray] = {}
sampling_params = generate_sampling_params_vllm(request)
parameters = {
"sampling_params": sampling_params,
"request_id": request_id,
"prompt": prompt,
}
return inputs, parameters
def translate_chat_outputs(
self, response: InferenceResponse, model_name: str
) -> Dict[str, Any]:
"""Translate the inference outputs to chat completion response"""
if "text" in response.parameters:
return {"model_output": [response.parameters["text"]]}
elif "text_output" in response.outputs:
return {"model_output": response.outputs["text_output"][0]}
return {}
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
import typing
class TritonInferenceError(Exception):
"""Error occurred during Triton inference."""
@dataclasses.dataclass
class InferenceRequest:
"""Inference request."""
inputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class InferenceResponse:
"""Inference response."""
outputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
error: typing.Optional[str] = None
final: bool = False
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
class BaseTriton3Connector(abc.ABC):
"""Base class for Triton 3 connector."""
@abc.abstractmethod
def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Returns:
Inference response.
Raises:
TritonInferenceError: error occurred during inference.
"""
raise NotImplementedError
async def list_models(self) -> typing.List[str]:
"""List models available in Triton 3 system.
Returns:
List of model names.
"""
raise NotImplementedError
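# Illustrative only (not part of this commit): a minimal sketch of a connector
# implementing the interface above. The class name "EchoConnector" is
# hypothetical; it simply mirrors the request inputs back as a single final
# response.
class EchoConnector(BaseTriton3Connector):
    async def inference(
        self, model_name: str, request: InferenceRequest
    ) -> typing.AsyncGenerator[InferenceResponse, None]:
        # One final response echoing the request inputs as outputs.
        yield InferenceResponse(outputs=dict(request.inputs), final=True)

    async def list_models(self) -> typing.List[str]:
        return ["echo"]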
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def parse_args():
parser = argparse.ArgumentParser(
description="OpenAI-Compatible API server.", prog="OpenAI API Sever"
)
# API Server
parser.add_argument(
"--api-server-host",
type=str,
required=False,
default="127.0.0.1",
help="API Server host",
)
parser.add_argument(
"--api-server-port",
type=int,
required=False,
default=8000,
help="API Server port",
)
# Request Plane
parser.add_argument(
"--request-plane-uri",
type=str,
required=False,
default="nats://localhost:4223",
help="URL of request plane",
)
# Data Plane
parser.add_argument(
"--data-plane-host",
type=str,
required=False,
default=None,
help="Data plane host",
)
parser.add_argument(
"--data-plane-port",
type=int,
required=False,
default=0,
help="Data plane port. (default: 0 means the system will choose a port)",
)
# Misc
parser.add_argument(
"--tokenizer",
type=str,
required=False,
default="meta-llama/Meta-Llama-3.1-8B-Instruct",
help="Tokenizer to use for chat template in chat completions API",
)
parser.add_argument(
"--model-name",
type=str,
required=False,
default="prefill",
help="Model name",
)
parser.add_argument(
"--log-level",
type=int,
required=False,
default=1,
help="Logging level",
)
parser.add_argument(
"--backend",
type=str,
required=False,
default="vllm",
choices=["vllm", "tensorrtllm"],
help="Backendtype",
)
return parser, parser.parse_args()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
# UCP data plane causes deadlocks when used more than once, so we use a singleton
_g_singletonic_data_plane = None
_g_singletonic_data_plane_connection_count = 0
_g_actual_host = None
_g_actual_port = None
def set_actual_host_port(host, port):
global _g_actual_host
global _g_actual_port
if _g_singletonic_data_plane is not None:
raise Exception("Cannot set actual host and port after data plane is created")
_g_actual_host = host
_g_actual_port = port
def set_data_plane(data_plane):
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
_g_singletonic_data_plane_connection_count = 1
_g_singletonic_data_plane = data_plane
class RemoteConnector:
"""Handle connection to both request and data planes."""
def __init__(
self,
nats_url: str,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize RemoteConnector.
Args:
nats_url (str): URL of NATS server.
"""
global _g_singletonic_data_plane
global _g_actual_port
global _g_actual_host
self._nats_url = nats_url
self._request_plane = NatsRequestPlane(nats_url)
if _g_singletonic_data_plane is None:
if _g_actual_host is not None:
data_plane_host = _g_actual_host
if _g_actual_port is not None:
data_plane_port = _g_actual_port
_g_singletonic_data_plane = UcpDataPlane(
hostname=data_plane_host,
port=data_plane_port,
keep_endpoints_open=keep_dataplane_endpoints_open,
)
self._connected = False
self._data_plane = _g_singletonic_data_plane
async def connect(self):
"""Connect to both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane is not None
await self._request_plane.connect()
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.connect()
_g_singletonic_data_plane_connection_count += 1
self._connected = True
async def close(self):
"""Disconnect from both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane is not None
await self._request_plane.close()
_g_singletonic_data_plane_connection_count -= 1
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.close()
_g_singletonic_data_plane = None
self._data_plane.close()
self._connected = False
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import typing
from typing import Any, Optional
import numpy as np
from llm.api_server.connector import (
BaseTriton3Connector,
InferenceRequest,
InferenceResponse,
)
from llm.api_server.remote_connector import RemoteConnector
from tritonserver import DataType
from triton_distributed.runtime.remote_operator import RemoteOperator
class RemoteModelConnector(BaseTriton3Connector):
"""Connector for Triton 3 model."""
def __init__(
self,
nats_url: str,
model_name: str,
model_version: Optional[str] = None,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize Triton 3 connector.
Args:
nats_url: NATS URL (e.g. "localhost:4222").
model_name: Model name.
model_version: Model version. Default is "1".
data_plane_host: Data plane host (e.g. "localhost").
data_plane_port: Data plane port (e.g. 8001). You can use 0 to let the system choose a port.
keep_dataplane_endpoints_open: Keep data plane endpoints open to avoid reconnecting. Default is False.
Example:
remote_model_connector = RemoteModelConnector(
nats_url="localhost:4222",
data_plane_host="localhost",
data_plane_port=8001,
model_name="model_name",
model_version="1",
)
async with remote_model_connector:
request = InferenceRequest(inputs={"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
async for response in remote_model_connector.inference(model_name="model_name", request=request):
print(response.outputs)
"""
self._connector = RemoteConnector(
nats_url,
data_plane_host,
data_plane_port,
keep_dataplane_endpoints_open=keep_dataplane_endpoints_open,
)
self._model_name = model_name
if model_version is None:
model_version = "1"
self._model_version = model_version
async def connect(self):
"""Connect to Triton 3 server."""
await self._connector.connect()
self._model = RemoteOperator(
operator=(self._model_name, self._model_version),
request_plane=self._connector._request_plane,
data_plane=self._connector._data_plane,
)
async def close(self):
"""Disconnect from Triton 3 server."""
await self._connector.close()
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
async def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Returns:
Inference response.
Raises:
TritonInferenceError: error occurred during inference.
"""
if not self._connector._connected:
await self.connect()
else:
if self._model_name != model_name:
self._model_name = model_name
self._model = RemoteOperator(
self._model_name,
self._connector._request_plane,
self._connector._data_plane,
)
results = []
for key, value in request.parameters.items():
if isinstance(value, dict):
request.parameters[key] = "JSON:" + json.dumps(value)
store_inputs_in_request = set()
for k, v in request.inputs.items():
store_inputs_in_request.add(k)
results.append(
self._model.async_infer(
inputs=request.inputs,
parameters=request.parameters,
store_inputs_in_request=store_inputs_in_request,
)
)
for result in asyncio.as_completed(results):
responses = await result
outputs = {}
async for response in responses:
for output_name, value in response.outputs.items():
try:
output_value: Any = None
if value.data_type == DataType.BYTES:
output_value = [value.to_string_array()]
else:
output_value = np.from_dlpack(value)
finally:
# FIXME: This is a workaround for the issue that the remote tensor
# is released after connection is closed.
# value.__del__()
pass
outputs[output_name] = output_value
infer_response = InferenceResponse(
outputs=outputs,
final=response.final,
parameters=response.parameters,
)
yield infer_response
async def list_models(self) -> typing.List[str]:
"""List models available in Triton 3 system.
Returns:
List of model names.
"""
# FIXME: Support multiple models
return [self._model_name]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
pytestmark = pytest.mark.pre_merge
def test_imports():
from engine.engine import LLMEngine as e
from frontend.fastapi_frontend import FastApiFrontend as f
# Placeholder to avoid unused import errors or removal by linters
assert e, f
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import AsyncIterator
from engine.engine import LLMEngine
from llm.api_server.chat_tensorrtllm import ChatHandlerTensorrtLLM
from llm.api_server.chat_vllm import ChatHandlerVllm
from llm.api_server.remote_model_connector import RemoteModelConnector
from schemas.openai import (
CreateChatCompletionRequest,
CreateChatCompletionResponse,
CreateCompletionRequest,
CreateCompletionResponse,
Model,
ObjectType,
)
class TritonDistributedTensorrtLLMChatHandler(ChatHandlerTensorrtLLM):
def __init__(
self, triton_connector: RemoteModelConnector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, model_name, tokenizer)
# Request / response format can vary between frontends, so allow override
# of adaptor functions accordingly.
def stream_response_adaptor(self, response_stream):
async def adaptor_stream():
async for response in response_stream():
if isinstance(response, Exception):
raise response
else:
# Already in SSE String format
yield response
return adaptor_stream
def response_adaptor(self, response):
return response
def exception_adaptor(self, exception):
raise exception
class TritonDistributedChatHandler(ChatHandlerVllm):
def __init__(
self, triton_connector: RemoteModelConnector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, model_name, tokenizer)
# Request / response format can vary between frontends, so allow override
# of adaptor functions accordingly.
def stream_response_adaptor(self, response_stream):
async def adaptor_stream():
async for response in response_stream():
if isinstance(response, Exception):
raise response
else:
# Already in SSE String format
yield response
return adaptor_stream
def response_adaptor(self, response):
return response
def exception_adaptor(self, exception):
raise exception
class TritonDistributedEngine(LLMEngine):
def __init__(
self,
nats_url: str,
data_plane_host: str,
data_plane_port: int,
model_name: str,
tokenizer: str,
backend: str,
):
self.triton_connector = RemoteModelConnector(
nats_url=nats_url,
data_plane_host=data_plane_host,
data_plane_port=data_plane_port,
model_name=model_name,
keep_dataplane_endpoints_open=True,
)
if not backend or backend == "vllm":
# FIXME: Consider supporting multiple or per-model tokenizers
self.request_handler = TritonDistributedChatHandler(
self.triton_connector, model_name, tokenizer
)
else:
self.request_handler = TritonDistributedTensorrtLLMChatHandler(
self.triton_connector, model_name, tokenizer
)
async def chat(
self, request: CreateChatCompletionRequest
) -> CreateChatCompletionResponse | AsyncIterator[str]:
"""
If request.stream is True, this returns an AsyncIterator (or Generator) that
produces server-sent-event (SSE) strings in the following form:
'data: {CreateChatCompletionStreamResponse}\n\n'
...
'data: [DONE]\n\n'
If request.stream is False, this returns a CreateChatCompletionResponse.
"""
# FIXME: Unify call whether streaming or not
if request.stream:
response_generator = await self.request_handler.process_request(
request, None
)
return response_generator()
response = await self.request_handler.process_request(request, None)
return response
async def completion(
self, request: CreateCompletionRequest
) -> CreateCompletionResponse | AsyncIterator[str]:
"""
If request.stream is True, this returns an AsyncIterator (or Generator) that
produces server-sent-event (SSE) strings in the following form:
'data: {CreateCompletionResponse}\n\n'
...
'data: [DONE]\n\n'
If request.stream is False, this returns a CreateCompletionResponse.
"""
raise NotImplementedError
def ready(self) -> bool:
"""
Returns True if the engine is ready to accept inference requests, or False otherwise.
"""
# FIXME: Add more useful checks if available.
return True
def metrics(self) -> str:
"""
Returns the engine's metrics in a Prometheus-compatible string format.
"""
raise NotImplementedError
def models(self) -> list[Model]:
"""
Returns a List of OpenAI Model objects.
"""
# FIXME: Support 'async def models()'
model_names = asyncio.run(self.triton_connector.list_models())
models = [
Model(
id=model_name,
object=ObjectType.model,
owned_by="Triton Distributed",
# FIXME: Need to track creation times, so set to 0 for now.
created=0,
)
for model_name in model_names
]
return models
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Union
import numpy as np
from fastapi import Header, HTTPException
# Utility function to convert response to JSON
def tensor_to_json(tensor: np.ndarray) -> Any:
"""Convert numpy tensor to JSON."""
if tensor.dtype.type is np.bytes_:
items = [item.decode("utf-8") for item in tensor.flat]
if len(items) == 1:
try:
json_object = json.loads(items[0])
return json_object
except Exception:
return items[0]
return items
return tensor.tolist()
def json_to_tensor(json_list: str) -> np.ndarray:
"""Convert JSON to numpy tensor."""
return np.char.encode(json_list, "utf-8")
def verify_headers(content_type: Union[str, None] = Header(None)):
"""Verify content type."""
if content_type != "application/json":
raise HTTPException(
status_code=415,
detail="Unsupported media type: {content_type}. It must be application/json",
)
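# Illustrative only (not part of this commit): round-tripping a JSON string
# through the helpers above. json_to_tensor encodes the string into a bytes
# ndarray, and tensor_to_json decodes the single element and parses it back:
#   tensor_to_json(json_to_tensor('{"a": 1}'))  # -> {"a": 1}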
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Disaggregated Serving with TensorRT-LLM
This example demonstrates **disaggregated serving** [^1] using Triton Distributed together with TensorRT-LLM engines. Disaggregated serving decouples the prefill (prompt encoding) and the decode (token generation) stages of large language model (LLM) inference into separate processes. This separation allows you to independently scale, optimize, and distribute resources for each stage.
In this example, you will deploy:
- An **OpenAI-compatible API server** (which receives requests and streams responses).
- One or more **prefill workers** (for encoding the prompt).
- One or more **decode workers** (for generating tokens based on the encoded prompt).
## 1. Prerequisites
1. **GPU Availability**
This setup requires at least two GPUs:
- One GPU is typically used by the **prefill** process.
- Another GPU is used by the **decode** process.
In production systems with heavier loads, you will typically allocate more GPUs across multiple prefill and decode workers.
2. **NATS or Another Coordination Service**
Triton Distributed uses NATS by default for coordination and message passing. Make sure your environment has a running NATS service accessible via a valid `nats://<address>:<port>` endpoint. By default, examples assume `nats://localhost:4223`. A minimal standalone launch sketch is shown after this list.
3. **HuggingFace**
- You need a Hugging Face account to download the model, and you must set the `HF_TOKEN` environment variable.
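If you prefer not to let the deploy script initialize the request plane for you (it can do so via the `--initialize-request-plane` flag used later in this example), you can start a standalone NATS server beforehand. A minimal sketch, assuming the `nats-server` binary is installed; adjust the port to match the `--request-plane-uri` you pass to the workers and API server, and enable JetStream (`--jetstream`) if your request plane configuration requires it:
```bash
nats-server --port 4223 &
```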
---
## 2. Building the Environment
The example is designed to run in a containerized environment using Triton Distributed, TensorRT-LLM, and associated dependencies. To build the container:
```bash
./container/build.sh --framework tensorrtllm
```
---
## 3. Starting the Deployment
Below is a minimal example of how to start each component of a disaggregated serving setup. The typical sequence is:
1. **Download and build model directories**
2. **Start the Context Worker(s) and Request Plane**
3. **Start the Generate Worker(s)**
4. **Start the API Server** (handles incoming requests and coordinates workers)
All components must be able to connect to the same request plane to coordinate.
### 3.1 Launch Interactive Environment
```bash
./container/run.sh --framework tensorrtllm -it
```
Note: All subsequent commands are run in the same container for simplicity.
Note: By default, this command makes all GPU devices visible. Use the `--gpus` flag to selectively expose GPU devices.
### 3.2: Build model directories
```bash
export HF_TOKEN=<YOUR TOKEN>
python3 /workspace/examples/llm/tensorrtllm/scripts/prepare_models.py --tp-size 1 --model llama-3.1-8b-instruct --max-num-tokens 8192
```
After this you should see the following in `/workspace/examples/llm/tensorrtllm/operators`
```bash
|-- hf_downloads
| `-- llama-3.1-8b-instruct
| |-- config.json
| |-- generation_config.json
| |-- model-00001-of-00004.safetensors
| |-- model-00002-of-00004.safetensors
| |-- model-00003-of-00004.safetensors
| |-- model-00004-of-00004.safetensors
| |-- model.safetensors.index.json
| |-- original
| | `-- params.json
| |-- special_tokens_map.json
| |-- tokenizer.json
| `-- tokenizer_config.json
|-- tensorrtllm_checkpoints
| `-- llama-3.1-8b-instruct
| `-- NVIDIA_H100_NVL
| `-- TP_1
| |-- config.json
| `-- rank0.safetensors
|-- tensorrtllm_engines
| `-- llama-3.1-8b-instruct
| `-- NVIDIA_H100_NVL
| `-- TP_1
| |-- config.json
| `-- rank0.engine
|-- tensorrtllm_models
| `-- llama-3.1-8b-instruct
| `-- NVIDIA_H100_NVL
| `-- TP_1
| |-- context
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| |-- generate
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| |-- llama-3.1-8b-instruct
| | |-- 1
| | `-- config.pbtxt
| |-- postprocessing
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| |-- preprocessing
| | |-- 1
| | | `-- model.py
| | `-- config.pbtxt
| `-- tensorrt_llm
| |-- 1
| | `-- model.py
| `-- config.pbtxt
`-- triton_core_models
|-- mock
| |-- 1
| | `-- model.py
| `-- config.pbtxt
|-- simple_postprocessing
| |-- 1
| | `-- model.py
| `-- config.pbtxt
`-- simple_preprocessing
|-- 1
| `-- model.py
`-- config.pbtxt
```
### 3.3: Deployment Example
To start a basic deployment with 1 prefill and 1 decode worker:
```bash
export MODEL_NAME="llama-3.1-8b-instruct"
python3 /workspace/examples/llm/tensorrtllm/deploy/launch_workers.py \
--context-worker-count 1 \
--generate-worker-count 1 \
--model ${MODEL_NAME} \
--initialize-request-plane \
--disaggregated-serving \
--request-plane-uri ${HOSTNAME}:4222 &
```
Then start the OpenAI-compatible API server:
```bash
python3 -m llm.api_server \
--tokenizer meta-llama/Llama-3.1-8B-Instruct \
--request-plane-uri ${HOSTNAME}:4222 \
--api-server-host ${HOSTNAME} \
--model-name ${MODEL_NAME} &
```
### 3.4: Sending Requests
Once the API server is running (by default on `localhost:8000`), you can send OpenAI-compatible requests. For example:
```bash
curl ${HOSTNAME}:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3.1-8b-instruct",
"messages": [
{"role": "user", "content": "Why is Roger Federer the greatest tennis player of all time?"}
],
"temperature": 0,
"top_p": 0.95,
"max_tokens": 25,
"stream": true,
"n": 1,
"frequency_penalty": 0.0,
"stop": []
}'
```
The above request will return a streamed response with the model’s answer.
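With `"stream": true`, the response body arrives as server-sent events (SSE). The shape is sketched below with made-up field values; each `data:` payload is a `chat.completion.chunk` object, and streaming behavior is subject to the limitations listed under Known Issues:
```
data: {"id": "<request-id>", "object": "chat.completion.chunk", "model": "llama-3.1-8b-instruct", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Roger Federer"}}]}

data: [DONE]
```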
## 4. Teardown
To tear down a deployment during local development, you can either kill the
container or kill the relevant processes involved in the deployment.
To kill the processes running inside the container, run:
```bash
pkill -SIGINT -f python3
pkill -SIGINT -f nats-server
```
You will generally want to make sure you have a clean slate between
deployments to avoid any unexpected errors.
NOTE: If you have other unrelated processes in the environment with `python3`
in the name, the `pkill` command above will terminate them as well. In this
scenario, you could select specific process IDs and use the following command
instead for each process ID replacing `<pid>` below:
```
kill -9 <pid>
```
## Known Issues & Limitations
1. **Tensor Parallelism Constraints**
- Currently limited to TP=1 for both prefill and decode workers
2. Currently streaming is not supported and results are returned all at once.
## References
[^1]: Yinmin Zhong, Shengyu Liu, Junda Chen, Jianbo Hu, Yibo Zhu, Xuanzhe Liu, Xin Jin, and Hao
Zhang. Distserve: Disaggregating prefill and decoding for goodput-optimized large language
model serving. *arXiv:2401.09670v3 [cs.DC]*, 2024.
For more details on Triton Distributed, see the [Hello World example](../../hello_world/) and [Triton Inference Server documentation](https://github.com/triton-inference-server/server).
# KV Aware Routing with TensorRT-LLM
This example also showcases smart routing based on worker KV usage in an aggregated scenario.
To start a KV aware deployment with 2 decode workers:
```bash
export HOSTNAME=localhost
export MODEL_NAME="llama-3.1-8b-instruct"
python3 /workspace/examples/python/llm/tensorrtllm/deploy/launch_workers.py \
--generate-worker-count 2 \
--model ${MODEL_NAME} \
--initialize-request-plane \
--kv-aware-routing \
--request-plane-uri ${HOSTNAME}:4222 &
```
```bash
python3 -m llm.api_server \
--tokenizer meta-llama/Llama-3.1-8B-Instruct \
--request-plane-uri ${HOSTNAME}:4222 \
--api-server-host ${HOSTNAME} \
--model-name ${MODEL_NAME} &
```
```bash
curl ${HOSTNAME}:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3.1-8b-instruct",
"messages": [
{"role": "user", "content": "Why is Roger Federer the greatest tennis player of all time? Roger Federer is widely regarded as one of the greatest tennis players of all time, and many consider him the greatest."}
],
"temperature": 0,
"top_p": 0.95,
"max_tokens": 25,
"stream": true,
"n": 1,
"frequency_penalty": 0.0,
"stop": []
}'
```
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signal
import sys
import time
from pathlib import Path
from llm.tensorrtllm.operators.disaggregated_serving import DisaggregatedServingOperator
from llm.tensorrtllm.operators.kv_aware_routing import KvAwareRoutingOperator
from llm.tensorrtllm.scripts.gpu_info import get_gpu_product_name
from triton_distributed.runtime import (
OperatorConfig,
TritonCoreOperator,
Worker,
WorkerConfig,
)
from .parser import parse_args
deployment = None
def handler(signum, frame):
exit_code = 0
if deployment:
print("Stopping Workers")
exit_code = deployment.stop()
print(f"Workers Stopped Exit Code {exit_code}")
sys.exit(exit_code)
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
try:
signal.signal(sig, handler)
except Exception:
pass
def _create_disaggregated_serving_op(name, args, max_inflight_requests):
model_repository = str(
Path(args.operator_repository) / "triton_core_models"
) # stores our simple pre/post processing
return OperatorConfig(
name=name,
implementation=DisaggregatedServingOperator,
max_inflight_requests=int(max_inflight_requests),
repository=model_repository,
)
def _create_kv_aware_routing_op(name, args, max_inflight_requests):
model_repository = str(
Path(args.operator_repository) / "triton_core_models"
) # stores our simple pre/post processing
return OperatorConfig(
name=name,
implementation=KvAwareRoutingOperator,
max_inflight_requests=int(max_inflight_requests),
repository=model_repository,
)
def _create_triton_core_op(
name,
max_inflight_requests,
args,
):
# TODO: argparse repo
gpu_name = get_gpu_product_name()
return OperatorConfig(
name=name,
implementation=TritonCoreOperator,
max_inflight_requests=int(max_inflight_requests),
repository=str(
Path(args.operator_repository)
/ "tensorrtllm_models"
/ args.model
/ gpu_name
/ "TP_1"
),
parameters={
"store_outputs_in_response": True,
"config": {
"parameters": {
"participant_ids": {"string_value": f"{args.gpu_device_id}"},
"gpu_device_ids": {"string_value": f"{args.gpu_device_id}"},
"event_buffer_max_size": {"string_value": "1024"},
}
},
},
)
def main(args):
if args.log_dir:
log_dir = Path(args.log_dir)
log_dir.mkdir(exist_ok=True)
worker_configs = []
if args.worker_type == "aggregate":
aggregate_op = _create_triton_core_op(
name=args.model, max_inflight_requests=1000, args=args
)
aggregate = WorkerConfig(
operators=[aggregate_op],
name=args.model,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
metrics_port=args.metrics_port,
)
worker_configs.append(aggregate)
# Context/Generate workers used for Disaggregated Serving
elif args.worker_type == "context":
prefill_op = _create_triton_core_op(
name="context",
max_inflight_requests=1000,
args=args,
)
prefill = WorkerConfig(
operators=[prefill_op],
name="context",
log_level=args.log_level,
metrics_port=args.metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
worker_configs.append(prefill)
elif args.worker_type == "generate":
decoder_op = _create_triton_core_op(
name="generate",
max_inflight_requests=1000,
args=args,
)
decoder = WorkerConfig(
operators=[decoder_op],
name="generate",
log_level=args.log_level,
metrics_port=args.metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
worker_configs.append(decoder)
elif args.worker_type == "disaggregated-serving":
prefill_decode_op = _create_disaggregated_serving_op(
name=args.model,
max_inflight_requests=1000,
args=args,
)
prefill_decode = WorkerConfig(
operators=[prefill_decode_op],
name=args.worker_name,
log_level=args.log_level,
metrics_port=args.metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
worker_configs.append(prefill_decode)
elif args.worker_type == "kv-aware-routing":
print("Creating KvAwareRouting Operator")
router_op = _create_kv_aware_routing_op(
name=args.model,
max_inflight_requests=1000,
args=args,
)
router = WorkerConfig(
operators=[router_op],
name=args.worker_name,
log_level=args.log_level,
metrics_port=args.metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
worker_configs.append(router)
print("Starting Worker")
for worker_config in worker_configs:
worker = Worker(worker_config)
print(f"worker: {worker}")
worker.start()
print("Worker started ... press Ctrl-C to Exit")
while True:
time.sleep(10)
if __name__ == "__main__":
args = parse_args()
main(args)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import signal
import subprocess
import sys
from pathlib import Path
from llm.tensorrtllm.deploy.parser import parse_args
deployment = None
def handler(signum, frame):
exit_code = 0
if deployment:
print("Stopping Workers")
exit_code = deployment.stop()
print(f"Workers Stopped Exit Code {exit_code}")
sys.exit(exit_code)
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
try:
signal.signal(sig, handler)
except Exception:
pass
def _launch_mpi_workers(args):
command = [
"mpiexec",
"--allow-run-as-root",
"--oversubscribe",
"--display-map",
"--verbose",
]
if args.log_dir:
WORKER_LOG_DIR = str(Path(args.log_dir) / "workers")
command += ["--output-filename", WORKER_LOG_DIR]
aggregate_gpus = 0
# [TODO] below placements assume model to be TP/PP 1
gpu_count_per_context_worker = 1
gpu_count_per_generate_worker = 1
gpu_count_per_aggregate_worker = 1
for index in range(args.context_worker_count):
starting_gpu = aggregate_gpus
command.extend(_context_cmd(args, index, starting_gpu))
command.append(":")
aggregate_gpus += gpu_count_per_context_worker
for index in range(args.generate_worker_count):
starting_gpu = aggregate_gpus
command.extend(_generate_cmd(args, index, starting_gpu))
command.append(":")
aggregate_gpus += gpu_count_per_generate_worker
for index in range(args.aggregate_worker_count):
starting_gpu = aggregate_gpus
command.extend(_aggregate_cmd(args, index, starting_gpu))
command.append(":")
aggregate_gpus += gpu_count_per_aggregate_worker
command = command[0:-1]
print(" ".join(command))
if args.dry_run:
return
env = os.environ.copy()
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
def _launch_disagg_model(args):
if not args.disaggregated_serving:
return
starting_gpu = 0
env = os.environ.copy()
command = _disaggregated_serving_cmd(args, starting_gpu)
print(" ".join(command))
if args.dry_run:
return
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
def _launch_kv_aware_model(args):
if not args.kv_aware_routing:
return
starting_gpu = 0
env = os.environ.copy()
command = _kv_aware_routing_cmd(args, starting_gpu)
print(" ".join(command))
if args.dry_run:
return
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
def _launch_workers(args):
# Launch nats-server if requested by user for convenience, otherwise
# it can be started separately beforehand.
if args.initialize_request_plane:
_launch_nats_server(args)
# [FIXME] not really related to request plane
_launch_etcd(args)
# Launch TRT-LLM models via mpiexec in the same MPI WORLD
_launch_mpi_workers(args)
# [FIXME] below should be "one of" or merged together
# Launch disaggregated serving "workflow" model to interface
# client-facing requests with Triton Distributed deployment.
_launch_disagg_model(args)
# Launch KV aware routing "workflow" model to interface
# client-facing requests with Triton Distributed deployment.
_launch_kv_aware_model(args)
def _context_cmd(args, index, starting_gpu):
# Hard-coded worker name for internal communication,
# see tensorrtllm.deploy script
worker_name = "context"
command = [
"-np",
"1",
# FIXME: May need to double check this CUDA_VISIBLE_DEVICES
# and trtllm gpu_device_id/participant_id interaction
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"context",
"--worker-name",
worker_name,
"--model",
args.model,
"--gpu-device-id",
f"{starting_gpu}",
"--metrics-port",
str(50100 + index),
"--initialize-request-plane",
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _generate_cmd(args, index, starting_gpu):
# Hard-coded worker name for internal communication
# see tensorrtllm.deploy script
worker_name = "generate"
command = [
"-np",
"1",
# FIXME: May need to double check this CUDA_VISIBLE_DEVICES
# and trtllm gpu_device_id/participant_id interaction
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"generate",
"--worker-name",
worker_name,
"--model",
args.model,
"--gpu-device-id",
f"{starting_gpu}",
"--metrics-port",
str(50200 + index),
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _aggregate_cmd(args, index, starting_gpu):
# Hard-coded worker name for internal communication
# see tensorrtllm.deploy script
worker_name = "aggregate"
command = [
"-np",
"1",
# FIXME: May need to double check this CUDA_VISIBLE_DEVICES
# and trtllm gpu_device_id/participant_id interaction
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"aggregate",
"--worker-name",
worker_name,
"--model",
args.model,
"--gpu-device-id",
f"{starting_gpu}",
"--metrics-port",
str(50300 + index),
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _disaggregated_serving_cmd(args, starting_gpu):
# NOTE: This worker is given args.worker_name because it receives the
# API-server-facing requests and internally handles the disaggregation,
# so this worker name should match the one registered with the API Server.
command = [
# FIXME: Does this model need a GPU assigned to it?
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"disaggregated-serving",
"--metrics-port",
"50002",
"--model",
args.model,
"--worker-name",
args.worker_name,
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _kv_aware_routing_cmd(args, starting_gpu):
# NOTE: This worker is given args.worker_name because it receives the
# API-server-facing requests and internally handles the KV-aware routing,
# so this worker name should match the one registered with the API Server.
command = [
# FIXME: Does this model need a GPU assigned to it?
# "-x",
# f"CUDA_VISIBLE_DEVICES={starting_gpu}",
"python3",
"-m",
"llm.tensorrtllm.deploy",
"--worker-type",
"kv-aware-routing",
"--metrics-port",
"50002",
"--model",
args.model,
"--worker-name",
args.worker_name,
"--request-plane-uri",
f"{os.getenv('HOSTNAME')}:{args.nats_port}",
]
return command
def _launch_nats_server(args, clear_store=True):
# FIXME: Use NatsServer object defined in icp package
store_dir = "/tmp/nats_store"
if clear_store:
shutil.rmtree(store_dir, ignore_errors=True)
command = [
"/usr/local/bin/nats-server",
"--jetstream",
"--port",
str(args.nats_port),
"--store_dir",
store_dir,
]
print(" ".join(command))
if args.dry_run:
return
env = os.environ.copy()
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
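# If --initialize-request-plane is not passed, an equivalent NATS server can be
# started manually beforehand, e.g. (a sketch using the defaults above):
#
#   nats-server --jetstream --port 4222 --store_dir /tmp/nats_store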
def _launch_etcd(args):
command = [
"/usr/local/bin/etcd",
]
print(" ".join(command))
if args.dry_run:
return
env = os.environ.copy()
return subprocess.Popen(command, env=env, stdin=subprocess.DEVNULL)
if __name__ == "__main__":
args = parse_args()
_launch_workers(args)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
def parse_args():
parser = argparse.ArgumentParser(
description="Run an example of the TensorRT-LLM pipeline."
)
example_dir = Path(__file__).parent.absolute().parent.absolute()
default_operator_repository = example_dir.joinpath("operators")
default_log_dir = ""
parser.add_argument(
"--log-dir",
type=str,
default=str(default_log_dir),
help="log dir folder",
)
parser.add_argument(
"--initialize-request-plane",
default=False,
action="store_true",
help="Initialize the request plane, should only be done once per deployment",
)
parser.add_argument(
"--log-level", type=int, default=1, help="log level applied to all workers"
)
parser.add_argument(
"--request-plane-uri",
type=str,
default="nats://localhost:4222",
help="URI of request plane",
)
parser.add_argument(
"--nats-port",
type=int,
default=4222,
help="Port for NATS server",
)
parser.add_argument(
"--metrics-port",
type=int,
default=50000,
help="Metrics port",
)
parser.add_argument(
"--worker-type",
type=str,
default="aggregate",
help="Type of worker",
choices=[
"aggregate",
"context",
"generate",
"disaggregated-serving",
"kv-aware-routing",
],
)
parser.add_argument("--gpu-device-id", type=int, default=0, help="gpu id")
parser.add_argument(
"--context-worker-count", type=int, default=0, help="Number of context workers"
)
parser.add_argument(
"--generate-worker-count",
type=int,
default=0,
help="Number of generate workers",
)
parser.add_argument(
"--aggregate-worker-count",
type=int,
required=False,
default=0,
help="Number of baseline workers",
)
parser.add_argument(
"--operator-repository",
type=str,
default=str(default_operator_repository),
help="Operator repository",
)
parser.add_argument(
"--worker-name",
type=str,
required=False,
default="llama",
help="Name of the worker",
)
parser.add_argument(
"--model",
type=str,
required=False,
default="llama-3.1-8b-instruct",
choices=[
"mock",
"llama-3.1-70b-instruct",
"llama-3.1-8b-instruct",
"llama-3-8b-instruct-generate",
"llama-3-8b-instruct-context",
"llama-3-8b-instruct",
"llama-3-8b-instruct-default",
"llama-3-70b-instruct-context",
"llama-3-70b-instruct-generate",
"llama-3-70b-instruct",
],
help="model to serve",
)
parser.add_argument(
"--ignore-eos",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Ignore EOS token when generating",
)
parser.add_argument(
"--dry-run",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Dry run the command",
)
parser.add_argument(
"--disaggregated-serving",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Enable disaggregated serving",
)
parser.add_argument(
"--kv-aware-routing",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Enable KV aware routing",
)
return parser.parse_args()
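# Example argument set for the launcher that consumes this parser (the module
# path is omitted intentionally; it depends on how the launcher script is
# installed in the deployment):
#
#   --initialize-request-plane \
#   --context-worker-count 1 \
#   --generate-worker-count 1 \
#   --disaggregated-serving \
#   --model llama-3.1-8b-instruct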
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llm.tensorrtllm.operators.disaggregated_serving import DisaggregatedServingOperator
from llm.tensorrtllm.operators.kv_aware_routing import KvAwareRoutingOperator
__all__ = ["DisaggregatedServingOperator", "KvAwareRoutingOperator"]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import numpy
from triton_distributed.runtime import (
RemoteInferenceRequest,
RemoteOperator,
TritonCoreOperator,
)
class DisaggregatedServingOperator(TritonCoreOperator):
def __init__(
self,
name,
version,
request_plane,
data_plane,
parameters,
repository,
logger,
triton_core,
):
self._prefill = RemoteOperator("context", request_plane, data_plane)
self._decode = RemoteOperator("generate", request_plane, data_plane)
self._repository = repository
self._triton_core = triton_core
self._triton_core.register_model_repository(repository)
self._preprocess_model = self._triton_core.load("simple_preprocessing")
self._postprocess_model = self._triton_core.load("simple_postprocessing")
self._logger = logger
self._store_outputs_in_response = True
async def execute(self, requests: list[RemoteInferenceRequest]):
self._logger.debug("Executing DisaggregatedServing Request")
background_tasks = []
for request in requests:
task = asyncio.create_task(self._execute_request(request))
background_tasks.append(task)
try:
results = await asyncio.gather(*background_tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
self._logger.exception(
f"Running request execution failed: {result}"
)
else:
self._logger.debug(
f"Request execution completed with result: {result}"
)
except Exception as e:
self._logger.exception(f"Error during request execution: {e}")
async def _execute_request(self, request: RemoteInferenceRequest):
background_tasks = []
prefill_inputs = {}
sampling_params = {}
response_sender = request.response_sender()
"""Preprocessing"""
self._logger.debug(request)
if "text_input" in request.inputs:
query = request.inputs["text_input"].to_bytes_array()
elif "prompt" in request.inputs:
query = request.inputs["prompt"].to_bytes_array()
elif "prompt" in request.parameters:
query = request.parameters["prompt"]
else:
await response_sender.send(error=f"invalid request {request}", final=True)
return
if "sampling_params" in request.parameters:
sampling_params = json.loads(
request.parameters["sampling_params"].removeprefix("JSON:")
)
if "max_tokens" in request.inputs:
request_output_len = request.inputs["max_tokens"]
elif "max_tokens" in sampling_params:
request_output_len = numpy.array(
[[sampling_params["max_tokens"]]], dtype=numpy.int32
)
streaming = request.parameters.get("streaming", False)
input_ids, input_lengths = await self._preprocess(query)
self._logger.debug(f"input_ids: {input_ids}, input_lengths: {input_lengths}")
prefill_inputs["input_ids"] = input_ids
prefill_inputs["input_lengths"] = input_lengths
prefill_inputs["request_output_len"] = request_output_len
"""Prefill"""
prefill_parameters = {}
prefill_parameters["request_type"] = "context_only"
self._logger.debug(
f"Executing request on context worker with inputs: {prefill_inputs}"
)
async for prefill_response in await self._prefill.async_infer(
inputs=prefill_inputs,
parameters=prefill_parameters,
):
self._logger.debug(f"Prefill response completed: {prefill_response}")
output_ids = numpy.from_dlpack(prefill_response.outputs["output_ids"])
self._logger.debug(f"Output IDs: {output_ids}")
if streaming:
task = asyncio.create_task(
self._send_llm_response(
prefill_response, response_sender, final=False
)
)
background_tasks.append(task)
"""Decode"""
decode_parameters = {}
decode_parameters["request_type"] = "generation_only"
decode_inputs = {}
decode_inputs["context_phase_params"] = prefill_response.outputs[
"context_phase_params"
]
decode_inputs["input_ids"] = input_ids
decode_inputs["input_lengths"] = input_lengths
decode_inputs["request_output_len"] = request_output_len
async for decode_response in await self._decode.async_infer(
inputs=decode_inputs,
parameters=decode_parameters,
):
self._logger.debug(f"Decode response completed: {decode_response}")
background_tasks.append(
asyncio.create_task(
self._send_llm_response(
decode_response,
response_sender,
final=decode_response.final,
)
)
)
try:
results = await asyncio.gather(*background_tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
self._logger.exception(
f"Sending response failed with exception: {result}"
)
else:
self._logger.debug(f"Response sent successfully: {result}")
except Exception as e:
self._logger.exception(f"Error during response sending: {e}")
for output in prefill_response.outputs:
del output
for output in decode_response.outputs:
del output
async def _preprocess(self, query):
start_ids = None
start_lengths = None
if isinstance(query, str):
query = [[query]]
async for preprocess_response in self._preprocess_model.async_infer(
inputs={"query": query}
):
self._logger.debug(f"Preprocess response completed: {preprocess_response}")
start_ids = numpy.from_dlpack(preprocess_response.outputs["start_ids"])
start_lengths = numpy.from_dlpack(
preprocess_response.outputs["start_lengths"]
)
return start_ids, start_lengths
async def _postprocessing(self, tokens_batch, sequence_lengths):
outputs = []
async for postprocess_response in self._postprocess_model.async_infer(
inputs={"tokens_batch": tokens_batch, "sequence_lengths": sequence_lengths}
):
self._logger.debug(f"Received postprocess response: {postprocess_response}")
output = postprocess_response.outputs["output"].to_string_array()
outputs.append(output)
return outputs
async def _send_llm_response(self, llm_response, response_sender, final):
tokens_batch = numpy.from_dlpack(llm_response.outputs["output_ids"])
self._logger.debug(f"Output ids length: {tokens_batch}")
sequence_length = numpy.from_dlpack(llm_response.outputs["sequence_length"])
output = await self._postprocessing(tokens_batch, sequence_length)
store_outputs_in_response = set()
if self._store_outputs_in_response:
store_outputs_in_response.add("text_output")
await response_sender.send(
outputs={"text_output": output[0]},
final=final,
store_outputs_in_response=store_outputs_in_response,
)
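# Rough client-side sketch (assumes an initialized request_plane/data_plane and
# that this operator is deployed under the worker name registered with the API
# server, e.g. "llama"); the input and parameter names mirror those read in
# _execute_request above:
#
#   frontend = RemoteOperator("llama", request_plane, data_plane)
#   async for response in await frontend.async_infer(
#       inputs={"text_input": [["Hello! How are you?"]]},
#       parameters={
#           "sampling_params": 'JSON:{"max_tokens": 64}',
#           "streaming": True,
#       },
#   ):
#       print(response.outputs["text_output"])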
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import numpy
from triton_distributed_rs import DistributedRuntime, KvRouter
from triton_distributed.runtime import (
RemoteInferenceRequest,
RemoteOperator,
TritonCoreOperator,
)
class KvAwareRoutingOperator(TritonCoreOperator):
def __init__(
self,
name,
version,
request_plane,
data_plane,
parameters,
repository,
logger,
triton_core,
):
loop = asyncio.get_running_loop()
self._runtime = DistributedRuntime(loop)
backend = self._runtime.namespace("router").component("generate")
self._router = KvRouter(self._runtime, backend)
self._generate = RemoteOperator("generate", request_plane, data_plane)
self._repository = repository
self._triton_core = triton_core
self._triton_core.register_model_repository(repository)
self._preprocess_model = self._triton_core.load("simple_preprocessing")
self._postprocess_model = self._triton_core.load("simple_postprocessing")
self._logger = logger
self._store_outputs_in_response = True
async def execute(self, requests: list[RemoteInferenceRequest]):
self._logger.debug("Executing KvAwareRouting Request")
background_tasks = []
for request in requests:
task = asyncio.create_task(self._execute_request(request))
background_tasks.append(task)
try:
results = await asyncio.gather(*background_tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
self._logger.exception(
f"Running request execution failed: {result}"
)
else:
self._logger.debug(
f"Request execution completed with result: {result}"
)
except Exception as e:
self._logger.exception(f"Error during request execution: {e}")
async def _execute_request(self, request: RemoteInferenceRequest):
background_tasks = []
sampling_params = {}
response_sender = request.response_sender()
"""Preprocessing"""
self._logger.debug(request)
if "text_input" in request.inputs:
query = request.inputs["text_input"].to_bytes_array()
elif "prompt" in request.inputs:
query = request.inputs["prompt"].to_bytes_array()
elif "prompt" in request.parameters:
query = request.parameters["prompt"]
else:
await response_sender.send(error=f"invalid request {request}", final=True)
return
if "sampling_params" in request.parameters:
sampling_params = json.loads(
request.parameters["sampling_params"].removeprefix("JSON:")
)
if "max_tokens" in request.inputs:
request_output_len = request.inputs["max_tokens"]
elif "max_tokens" in sampling_params:
request_output_len = numpy.array(
[[sampling_params["max_tokens"]]], dtype=numpy.int32
)
input_ids, input_lengths = await self._preprocess(query)
self._logger.debug(f"input_ids: {input_ids}, input_lengths: {input_lengths}")
# [FIXME] not rate limiting because metric polling is not supported
# KV aware routing
lora_id = 0
try:
self._generate.component_id = await self._router.schedule(
input_ids[0], lora_id
)
self._logger.debug(f"worker selected: {self._generate.component_id}")
except Exception as e:
if "No worker found" in str(e):
self._generate.component_id = None
self._logger.debug("no eligible worker")
else:
self._logger.exception(f"Error during selecting worker: {e}")
# [TODO] add disaggregated example
"""llm"""
llm_inputs = {}
llm_inputs["input_ids"] = input_ids
llm_inputs["input_lengths"] = input_lengths
llm_inputs["request_output_len"] = request_output_len
async for llm_response in await self._generate.async_infer(
inputs=llm_inputs,
):
self._logger.debug(f"llm response completed: {llm_response}")
background_tasks.append(
asyncio.create_task(
self._send_llm_response(
llm_response,
response_sender,
final=llm_response.final,
)
)
)
try:
results = await asyncio.gather(*background_tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
self._logger.exception(
f"Sending response failed with exception: {result}"
)
else:
self._logger.debug(f"Response sent successfully: {result}")
except Exception as e:
self._logger.exception(f"Error during response sending: {e}")
for output in llm_response.outputs:
del output
async def _preprocess(self, query):
start_ids = None
start_lengths = None
if isinstance(query, str):
query = [[query]]
async for preprocess_response in self._preprocess_model.async_infer(
inputs={"query": query}
):
self._logger.debug(f"Preprocess response completed: {preprocess_response}")
start_ids = numpy.from_dlpack(preprocess_response.outputs["start_ids"])
start_lengths = numpy.from_dlpack(
preprocess_response.outputs["start_lengths"]
)
return start_ids, start_lengths
async def _postprocessing(self, tokens_batch, sequence_lengths):
outputs = []
async for postprocess_response in self._postprocess_model.async_infer(
inputs={"tokens_batch": tokens_batch, "sequence_lengths": sequence_lengths}
):
self._logger.debug(f"Received postprocess response: {postprocess_response}")
output = postprocess_response.outputs["output"].to_string_array()
outputs.append(output)
return outputs
async def _send_llm_response(self, llm_response, response_sender, final):
tokens_batch = numpy.from_dlpack(llm_response.outputs["output_ids"])
self._logger.debug(f"Output ids length: {tokens_batch}")
sequence_length = numpy.from_dlpack(llm_response.outputs["sequence_length"])
output = await self._postprocessing(tokens_batch, sequence_length)
store_outputs_in_response = set()
if self._store_outputs_in_response:
store_outputs_in_response.add("text_output")
await response_sender.send(
outputs={"text_output": output[0]},
final=final,
store_outputs_in_response=store_outputs_in_response,
)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import threading
import time
import numpy as np
import triton_python_backend_utils as pb_utils
DEFAULT_OUTPUT_LEN = 1000
class TritonPythonModel:
def initialize(self, args):
self._logger = pb_utils.Logger
model_config = json.loads(args["model_config"])
self._generate_token_latency = (
float(
model_config["parameters"]["generate_token_latency_ms"]["string_value"]
)
) / 1000
self._context_token_latency = (
float(
model_config["parameters"]["context_token_latency_ms"]["string_value"]
)
) / 1000
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
model_config
)
if not using_decoupled:
raise pb_utils.TritonModelException(
"""the model `{}` can generate any number of responses per request,
enable decoupled transaction policy in model configuration to
serve this model""".format(
args["model_name"]
)
)
for output_name in ["output_ids", "sequence_length", "context_phase_params"]:
setattr(
self,
output_name.lower() + "_dtype",
pb_utils.triton_string_to_numpy(
pb_utils.get_output_config_by_name(model_config, output_name)[
"data_type"
]
),
)
# Keep track of response threads so that we can delay finalizing
# the model until all response threads have completed.
self.inflight_thread_count = 0
self.inflight_thread_count_lck = threading.Lock()
def response_thread(self, response_sender, inputs):
streaming = inputs["streaming"][0]
request_type = inputs["request_type"]
output_ids = []
output_sequence_length = inputs["request_output_len"][0]
self._logger.log_verbose(
f"Starting Response Thread: {threading.get_native_id()}"
)
self._logger.log_verbose(f"Inputs: {inputs}")
self._logger.log_verbose(f"Streaming: {streaming}")
self._logger.log_verbose(f"Request Type: {request_type}")
input_sequence_length = inputs["input_lengths"][0]
if inputs["request_type"] != "generate_only":
for _ in inputs["input_ids"][0]:
time.sleep(self._context_token_latency)
if request_type != "context_only":
for index in range(output_sequence_length):
output_ids.append(inputs["input_ids"][0][index % input_sequence_length])
if streaming:
output_ids_tensor = pb_utils.Tensor(
"output_ids",
np.array([[[output_ids[-1]]]]).astype(self.output_ids_dtype),
)
sequence_length_tensor = pb_utils.Tensor(
"sequence_length",
np.array([[1]]).astype(self.sequence_length_dtype),
)
response = pb_utils.InferenceResponse(
output_tensors=[output_ids_tensor, sequence_length_tensor]
)
flags = (
pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
if index == output_sequence_length - 1
else 0
)
response_sender.send(response, flags=flags)
time.sleep(self._generate_token_latency)
if not streaming:
output_ids_tensor = pb_utils.Tensor(
"output_ids", np.array([[output_ids]]).astype(self.output_ids_dtype)
)
sequence_length_tensor = pb_utils.Tensor(
"sequence_length",
np.array([[output_sequence_length]]).astype(
self.sequence_length_dtype
),
)
response = pb_utils.InferenceResponse(
output_tensors=[output_ids_tensor, sequence_length_tensor]
)
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
if request_type == "context_only":
output_ids.append(inputs["input_ids"][0][0])
output_ids_tensor = pb_utils.Tensor(
"output_ids",
np.array([[output_ids]]).astype(self.output_ids_dtype),
)
sequence_length_tensor = pb_utils.Tensor("sequence_length", np.array([[1]]).astype(self.sequence_length_dtype))
context_phase_params = pb_utils.Tensor(
"context_phase_params",
np.array([[1, 2, 3, 4]]).astype(self.context_phase_params_dtype),
)
response = pb_utils.InferenceResponse(
output_tensors=[
output_ids_tensor,
sequence_length_tensor,
context_phase_params,
]
)
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
# We must close the response sender to indicate to Triton that we are
# done sending responses for the corresponding request. The response
# sender cannot be used after it is closed. Here it is closed by setting
# the TRITONSERVER_RESPONSE_COMPLETE_FINAL flag on the last response.
# response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
with self.inflight_thread_count_lck:
self.inflight_thread_count -= 1
self._logger.log_verbose(
f"Exiting Response Thread: {threading.get_native_id()}"
)
def _get_inputs(self, request):
inputs = [
"context_phase_params",
"streaming",
"min_length",
"request_output_len",
"input_lengths",
"input_ids",
]
result = {}
for input_ in inputs:
value = pb_utils.get_input_tensor_by_name(request, input_)
if value is not None:
result[input_] = value.as_numpy()
input_parameters = json.loads(request.parameters())
if "request_type" in input_parameters:
result["request_type"] = input_parameters["request_type"]
else:
result["request_type"] = "aggregate"
if "request_output_len" not in inputs:
result["request_output_len"] = DEFAULT_OUTPUT_LEN
if "streaming" not in result:
result["streaming"] = [False]
return result
def execute(self, requests):
for idx, request in enumerate(requests):
inputs = self._get_inputs(request)
# Start a separate thread to send the responses for the request.
# Sending the responses is delegated entirely to this thread.
thread = threading.Thread(
target=self.response_thread,
args=(request.get_response_sender(), inputs),
)
# A model using the decoupled transaction policy is not required to send
# all responses for the current request before returning from execute.
# To demonstrate the flexibility of the decoupled API, the response
# thread runs entirely independently of the execute thread.
thread.daemon = True
with self.inflight_thread_count_lck:
self.inflight_thread_count += 1
thread.start()
return None
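# For reference, a minimal sketch of the model configuration this mock model
# expects (parameter and output names taken from initialize(); the latency
# values and elided fields are illustrative only):
#
#   model_transaction_policy { decoupled: true }
#   parameters { key: "context_token_latency_ms"  value: { string_value: "10" } }
#   parameters { key: "generate_token_latency_ms" value: { string_value: "5" } }
#   output [
#     { name: "output_ids" ... },
#     { name: "sequence_length" ... },
#     { name: "context_phase_params" ... }
#   ]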