"lib/llm/src/preprocessor.rs" did not exist on "4f6f63cd1692d1929bafac54232f608550719aa4"
Commit 607fac29 authored by Neelay Shah, committed by GitHub

refactor: openai server skeleton


Co-authored-by: rmccorm4 <21284872+rmccorm4@users.noreply.github.com>
Co-authored-by: piotrm-nvidia <62554872+piotrm-nvidia@users.noreply.github.com>
parent c3b84790
@@ -27,6 +27,25 @@ USER root
RUN apt-get update; apt-get install -y gdb
# Install OpenAI-compatible frontend and its dependencies from triton server
# repository. These are used to have a consistent interface, schema, and FastAPI
# app between Triton Core and Triton Distributed implementations.
# NOTE: Current commit is == r24.12 + enum serialization fix
ARG SERVER_OPENAI_COMMIT="2ebd762fa6c7b829e7d04bfaf80c8400a09d3767"
RUN mkdir -p /opt/tritonserver/python && \
cd /opt/tritonserver/python && \
rm -rf openai && \
git clone https://github.com/triton-inference-server/server.git && \
cd server && \
git checkout ${SERVER_OPENAI_COMMIT} && \
cd .. && \
mv server/python/openai openai && \
chown -R root:root openai && \
chmod 755 openai && \
chmod -R go-w openai && \
rm -rf server && \
python3 -m pip install -r openai/requirements.txt
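# NOTE: The cloned frontend lands under /opt/tritonserver/python/openai; it is
# made importable by the PYTHONPATH update later in this Dockerfile.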
# Common dependencies
RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \
pip install --timeout=2000 --requirement /tmp/requirements.txt
@@ -99,22 +118,6 @@ RUN rm -rf /etc/nginx/sites-enabled/default
RUN apt-get install nvtop -y
RUN apt-get install tmux -y
# Install OpenAI-compatible frontend and its dependencies
# NOTE: Has a couple compat fixes needed for kserve frontends in 24.11 release.
# Can be replaced with specific release like r24.12 in future for stability.
ARG SERVER_OPENAI_COMMIT="f336fa6bd5416ba2f17e5eb7de5228213562bbc8"
WORKDIR /opt/tritonserver
RUN git clone https://github.com/triton-inference-server/server.git && \
cd server && \
git checkout ${SERVER_OPENAI_COMMIT} && \
cd .. && \
mv server/python/openai openai && \
chown -R root:root openai && \
chmod 755 openai && \
chmod -R go-w openai && \
rm -rf server && \
python3 -m pip install -r openai/requirements.txt
##########################################################
# Tokenizers #
##########################################################
@@ -133,7 +136,7 @@ COPY . /workspace
RUN /workspace/icp/protos/gen_python.sh
# Sets pythonpath for python modules
ENV PYTHONPATH="${PYTHONPATH}:/workspace/icp/src/python:/workspace/worker/src/python:/workspace/examples"
ENV PYTHONPATH="${PYTHONPATH}:/workspace/icp/src/python:/workspace/worker/src/python:/workspace/examples:/opt/tritonserver/python/openai/openai_frontend"
# Command and Entrypoint
CMD []
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
fastapi==0.111.0
fastapi==0.115.6
ftfy
grpcio-tools==1.66.0
httpx
@@ -23,6 +23,7 @@ opentelemetry-api
opentelemetry-sdk
pre-commit
protobuf==5.27.3
pydantic==2.7.1
pyright
pytest-md-report
pytest-mypy
@@ -31,4 +32,3 @@ transformers
tritonclient==2.53.0
# TODO: See whether TRT-LLM installs a different version of UCX. Need to revisit and track this dependency.
ucx-py-cu12
uvicorn==0.30.6
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from frontend.fastapi_frontend import FastApiFrontend
from llm.api_server.triton_distributed_engine import TritonDistributedEngine
from triton_distributed.worker.log_formatter import setup_logger
from .parser import parse_args
def main(args):
print(args)
logger = setup_logger(args.log_level, args.program_name)
logger.info("Starting")
# Wrap Triton Distributed in an interface-conforming "LLMEngine"
engine: TritonDistributedEngine = TritonDistributedEngine(
nats_url=args.request_plane_uri,
data_plane_host=args.data_plane_host,
data_plane_port=args.data_plane_port,
model_name=args.model_name,
tokenizer=args.tokenizer,
)
# Attach TritonLLMEngine as the backbone for inference and model management
openai_frontend: FastApiFrontend = FastApiFrontend(
engine=engine,
host=args.api_server_host,
port=args.api_server_port,
log_level=args.log_level,
)
# Blocking call until killed or interrupted with SIGINT
openai_frontend.start()
if __name__ == "__main__":
parser, args = parse_args()
args.program_name = parser.prog
main(args)
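# Example launch (illustrative sketch; the module filename is assumed, flags
# match parser.py in this commit):
#   python3 main.py \
#     --model-name prefill \
#     --tokenizer meta-llama/Meta-Llama-3.1-8B-Instruct \
#     --request-plane-uri nats://localhost:4223 \
#     --api-server-host 127.0.0.1 --api-server-port 8000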
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from llm.api_server.connector import (
BaseTriton3Connector,
InferenceRequest,
InferenceResponse,
TritonInferenceError,
)
# FIXME: Integrate better with api_server library
from schemas.openai import (
ChatCompletionChoice,
ChatCompletionFinishReason,
ChatCompletionResponseMessage,
ChatCompletionStreamingResponseChoice,
ChatCompletionStreamResponseDelta,
CreateChatCompletionRequest,
CreateChatCompletionResponse,
CreateChatCompletionStreamResponse,
ObjectType,
)
from transformers import AutoTokenizer
LOGGER = logging.getLogger(__name__)
"""
Example request with curl
curl -X 'POST' \\
'http://{host}:{port}/v1/chat/completions' \\
-H 'accept: application/json' \\
-H 'Content-Type: application/json' \\
-d '{{
"model": "{model}",
"messages": [
{{
"role":"user",
"content":"Hello! How are you?"
}},
{{
"role":"assistant",
"content":"Hi! I am quite well, how can I help you today?"
}},
{{
"role":"user",
"content":"Can you write me a song?"
}}
],
"top_p": 1,
"n": 1,
"max_tokens": 15,
"stream": true,
"frequency_penalty": 1.0,
"stop": ["hello"]
}}'
"""
def generate_sampling_params(
request: CreateChatCompletionRequest,
non_supported_params_none: Optional[List[str]] = None,
) -> Dict[str, Any]:
errors_message = ""
if not non_supported_params_none:
non_supported_params_none = []
for param in non_supported_params_none:
if getattr(request, param, None) is not None:
errors_message += f"The parameter '{param}' is not supported. "
if errors_message:
raise ValueError(errors_message)
sampling_params = {}
if request.temperature is not None:
sampling_params["temperature"] = request.temperature
if request.n is not None:
sampling_params["n"] = request.n
if request.top_p is not None:
sampling_params["top_p"] = request.top_p
if request.presence_penalty is not None:
sampling_params["presence_penalty"] = request.presence_penalty
if request.frequency_penalty is not None:
sampling_params["frequency_penalty"] = request.frequency_penalty
if request.max_tokens is not None:
sampling_params["max_tokens"] = request.max_tokens
if request.min_tokens is not None:
sampling_params["min_tokens"] = request.min_tokens
if request.stop is not None:
sampling_params["stop"] = request.stop
if request.seed is not None:
sampling_params["seed"] = request.seed
return sampling_params
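# Example (illustrative): a request with temperature=0.5 and max_tokens=15
# yields {"temperature": 0.5, "max_tokens": 15}; any parameter listed in
# non_supported_params_none that is set on the request raises ValueError.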
def create_chat_response(
request_id: str,
model: str,
model_output: Union[np.ndarray, List[str]],
role: str,
prompt: str,
) -> CreateChatCompletionResponse:
"""Create chunk responses from the detokenized outputs for non-streaming completions"""
detokenized_outputs = model_output
# Extract prompt from detokenized_outputs
cleaned_outputs = []
for detokenized_output in detokenized_outputs:
# FIXME: Should this be handled by 'echo' param instead?
if detokenized_output.startswith(prompt):
cleaned_output = detokenized_output[len(prompt) :]
else:
cleaned_output = detokenized_output
cleaned_outputs.append(cleaned_output)
messages = [
ChatCompletionResponseMessage(role=role, content=output_str)
for output_str in cleaned_outputs
]
choices = [
ChatCompletionChoice(
index=idx,
message=message,
finish_reason=ChatCompletionFinishReason.stop,
logprobs=None,
)
for idx, message in enumerate(messages)
]
chat_response = CreateChatCompletionResponse(
id=request_id,
choices=choices,
created=int(time.time()),
model=model,
system_fingerprint=None,
object=ObjectType.chat_completion,
)
return chat_response
def generate_delta(
output_str: str, role: str, previous_output: Optional[str] = None
) -> ChatCompletionStreamResponseDelta:
"""Generate the delta from the output string
Args:
output_str (str): The output string from the model.
role (str): The role of the AI generating the output.
previous_output (Optional[str]): The previous output string. Defaults to None.
Example:
print(generate_delta("user: Hello!, assistant: Hi!", "assistant", "user: Hello!, assistant: "))
# Output: Delta(role='assistant', content='Hi!')
"""
if previous_output is None:
return ChatCompletionStreamResponseDelta(role=role, content=output_str)
else:
# FIXME: Should we be manually finding the delta here? Or full text from last response?
delta_start = output_str.find(previous_output)
if delta_start == -1:
LOGGER.warning(
f"Previous output \n<START>\n{previous_output}\n<END>\n not found in the output string: \n<START>\n{output_str}\n<END>\n"
)
return ChatCompletionStreamResponseDelta(role=role, content=output_str)
delta = output_str[delta_start + len(previous_output) :]
return ChatCompletionStreamResponseDelta(role=role, content=delta)
def create_chunk_responses(
request_id: str,
model: str,
model_output: List[str],
role: str,
previous_output: Optional[list[str]] = None,
logprobs: Optional[np.ndarray] = None,
finish_reason: Optional[str] = None,
) -> Tuple[CreateChatCompletionStreamResponse, List[str]]:
"""Create chunk responses from the detokenized outputs for streaming completions.
Function extracts the delta from the output string and creates a chunk response. It also updates the previous output.
Args:
request_id (str): The unique identifier for the request.
model (str): The model used for generating the response.
model_output (List[str]): The list of output strings from the model.
role (str): The role of the AI generating the output.
previous_output (Optional[List[str]]): The previous output strings, one per choice. Defaults to None.
logprobs (Optional[np.ndarray]): The log probabilities of the output tokens. Defaults to None.
finish_reason (Optional[str]): The reason for stopping the completion. Defaults to None.
Returns:
Tuple[CreateChatCompletionStreamResponse, List[str]]: A tuple containing the chunk response and the new previous output.
Example:
create_chunk_responses(
request_id="chatcmpl-123",
model="gpt-4o-mini",
model_output=["I am fine, thank you!", "How can I help you?"],
role="assistant",
previous_output="user: Hello!, assistant: ",
logprobs=np.array([[0.1, 0.2, 0.3, 0.4, 0.5]]),
finish_reason="stop"
))
"""
detokenized_outputs = model_output
new_previous_output = []
deltas = []
for idx, output_str in enumerate(detokenized_outputs):
if previous_output is not None:
previous_output_row = previous_output[idx]
else:
previous_output_row = None
delta = generate_delta(
output_str=output_str, role=role, previous_output=previous_output_row
)
deltas.append(delta)
new_previous_output.append(output_str)
choices = []
for idx, delta in enumerate(deltas):
choice_kwargs = {
"index": idx,
"delta": delta,
# FIXME: Validate finish_reason behavior on first vs last responses
"finish_reason": finish_reason,
"logprobs": None,
}
if logprobs is not None:
choice_kwargs["logprobs"] = logprobs[idx]
chunk_choice = ChatCompletionStreamingResponseChoice(**choice_kwargs)
choices.append(chunk_choice)
chunk_response = CreateChatCompletionStreamResponse(
id=request_id,
object=ObjectType.chat_completion_chunk,
created=int(time.time()),
model=model,
system_fingerprint=request_id,
choices=choices,
)
return chunk_response, new_previous_output
class ChatHandler:
def __init__(self, triton_connector: BaseTriton3Connector, tokenizer: str):
self._triton_connector = triton_connector
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
def translate_chat_inputs(
self, request: CreateChatCompletionRequest, request_id: str, prompt: str
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
raise NotImplementedError("This method is not implemented yet")
def translate_chat_outputs(
self, response: InferenceResponse, model_name: str
) -> Dict[str, Any]:
raise NotImplementedError("This method is not implemented yet")
def stream_response_adaptor(self, response_stream):
async def adaptor_stream():
async for response in response_stream():
if isinstance(response, Exception):
yield self.exception_adaptor(response).body
else:
yield response.model_dump_json() + "\n"
return StreamingResponse(adaptor_stream(), media_type="application/json")
def response_adaptor(self, response):
return response.model_dump_json()
def exception_adaptor(self, exception):
return JSONResponse(
content={"error": str(exception), "code": 500}, status_code=500
)
async def process_request(self, request: Any, raw_request: Optional[Request]):
request_id = str(uuid.uuid4())
LOGGER.debug(f"{request=}")
prompt, role = self._create_prompt(request)
inputs, parameters = self.translate_chat_inputs(request, request_id, prompt)
triton_request = InferenceRequest(inputs=inputs, parameters=parameters)
# Streaming
if request.stream:
response_stream = self._stream_response_factory(
request_id, request.model, triton_request, prompt, role
)
return self.stream_response_adaptor(response_stream)
# Non-Streaming
response_data = None
try:
chat_outputs = None
async for response in self._triton_connector.inference(
request.model, triton_request
):
chat_outputs = self.translate_chat_outputs(response, request.model)
kwargs = {
"request_id": request_id,
"model": request.model,
"role": role,
"prompt": prompt,
}
if chat_outputs is not None:
kwargs.update(chat_outputs)
response_data = create_chat_response(**kwargs)
except TritonInferenceError as e:
logging.error(f"Error processing chat completion request: {e}")
return self.exception_adaptor(e)
LOGGER.info(f"Chat completion response: {response_data}")
return self.response_adaptor(response_data)
def _stream_response_factory(self, request_id, model, triton_request, prompt, role):
async def stream_response():
try:
previous_output = None
async for response in self._triton_connector.inference(
model, triton_request
):
# FIXME: Detect stop in response
try:
chat_outputs = self.translate_chat_outputs(response, model)
except KeyError as e:
LOGGER.info(f"KeyError {e} in response: {response}")
break
model_output = chat_outputs["model_output"]
chunk_response, new_previous_output = create_chunk_responses(
request_id=request_id,
model=model,
model_output=model_output,
role=role,
previous_output=previous_output,
)
previous_output = new_previous_output
yield f"data: {chunk_response.model_dump_json(exclude_unset=True)}\n\n"
except TritonInferenceError as e:
logging.error(f"Error processing chat completion request: {e}")
# FIXME: Does this need to conform to SSE standard for errors?
yield JSONResponse(
content={"error": str(e), "code": 500}, status_code=500
).body
finally:
yield "data: [DONE]\n\n"
return stream_response
# FIXME: Use shared/common module for these functions between
# TritonLLMEngine and TritonDistributedEngine implementations.
def _get_first_response_role(
self, conversation: List[Dict], add_generation_prompt: bool, default_role: str
) -> str:
if add_generation_prompt:
return default_role
return conversation[-1]["role"]
def _create_prompt(self, request: CreateChatCompletionRequest) -> Tuple[str, str]:
"""Create a prompt for vLLM model from the messages"""
conversation = [
message.model_dump(exclude_none=True) for message in request.messages
]
add_generation_prompt = True
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
LOGGER.debug(f"{prompt=}")
default_role = "assistant"
role = self._get_first_response_role(
conversation, add_generation_prompt, default_role
)
return prompt.strip(), role
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from llm.api_server.chat import ChatHandler, generate_sampling_params
from llm.api_server.connector import BaseTriton3Connector, InferenceResponse
from schemas.openai import CreateChatCompletionRequest
LOGGER = logging.getLogger(__name__)
# FIXME: Share request conversion logic where applicable
def generate_sampling_params_vllm(
request: CreateChatCompletionRequest,
non_supported_params: Optional[List[str]] = None,
) -> dict:
"""
Generate sampling params for vLLM from the request.
Args:
request: CreateChatCompletionRequest object.
non_supported_params: Parameter names that must remain unset on the request; a default list is used when None.
Returns:
dict: Sampling params for vLLM.
"""
errors_message = ""
if request.logprobs:
errors_message += "The parameter 'logprobs' set to True is not supported. "
if request.tools and request.tools.type != "text":
errors_message += (
f"The parameter 'tools' type {request.tools.type} is not supported. "
)
if errors_message:
raise ValueError(errors_message)
if non_supported_params is None:
non_supported_params = [
"logit_bias",
"top_logprobs",
"tool_choice",
"user",
"service_tier",
]
sampling_params = generate_sampling_params(request, non_supported_params)
# NOTE: vLLM parameters (ex: top_k) not supported until added to schema
return sampling_params
class ChatHandlerVllm(ChatHandler):
def __init__(
self, triton_connector: BaseTriton3Connector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, tokenizer)
self._model_name = model_name
def translate_chat_inputs(
self, request: CreateChatCompletionRequest, request_id: str, prompt: str
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""Translate the chat completion request to inference request"""
if self._model_name is not None and self._model_name != request.model:
raise ValueError(
f"Model name mismatch: {self._model_name} != {request.model}"
)
inputs: Dict[str, np.ndarray] = {}
sampling_params = generate_sampling_params_vllm(request)
parameters = {
"sampling_params": sampling_params,
"request_id": request_id,
"prompt": prompt,
}
return inputs, parameters
def translate_chat_outputs(
self, response: InferenceResponse, model_name: str
) -> Dict[str, Any]:
"""Translate the inference outputs to chat completion response"""
return {"model_output": [response.parameters["text"]]}
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
import typing
class TritonInferenceError(Exception):
"""Error occurred during Triton inference."""
@dataclasses.dataclass
class InferenceRequest:
"""Inference request."""
inputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class InferenceResponse:
"""Inference response."""
outputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
error: typing.Optional[str] = None
final: bool = False
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
class BaseTriton3Connector(abc.ABC):
"""Base class for Triton 3 connector."""
@abc.abstractmethod
def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Yields:
Inference responses.
Raises:
TritonInferenceError: error occurred during inference.
"""
raise NotImplementedError
async def list_models(self) -> typing.List[str]:
"""List models available in Triton 3 system.
Returns:
List of model names.
"""
raise NotImplementedError
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def parse_args():
parser = argparse.ArgumentParser(
description="OpenAI-Compatible API server.", prog="OpenAI API Sever"
)
# API Server
parser.add_argument(
"--api-server-host",
type=str,
required=False,
default="127.0.0.1",
help="API Server host",
)
parser.add_argument(
"--api-server-port",
type=int,
required=False,
default=8000,
help="API Server port",
)
# Request Plane
parser.add_argument(
"--request-plane-uri",
type=str,
required=False,
default="nats://localhost:4223",
help="URL of request plane",
)
# Data Plane
parser.add_argument(
"--data-plane-host",
type=str,
required=False,
default=None,
help="Data plane host",
)
parser.add_argument(
"--data-plane-port",
type=int,
required=False,
default=0,
help="Data plane port. (default: 0 means the system will choose a port)",
)
# Misc
parser.add_argument(
"--tokenizer",
type=str,
required=False,
default="meta-llama/Meta-Llama-3.1-8B-Instruct",
help="Tokenizer to use for chat template in chat completions API",
)
parser.add_argument(
"--model-name",
type=str,
required=False,
default="prefill",
help="Model name",
)
parser.add_argument(
"--log-level",
type=int,
required=False,
default=1,
help="Logging level",
)
return parser, parser.parse_args()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
# UCP data plane causes deadlocks when used more than once, so we use a singleton
_g_singletonic_data_plane = None
_g_singletonic_data_plane_connection_count = 0
_g_actual_host = None
_g_actual_port = None
def set_actual_host_port(host, port):
global _g_actual_host
global _g_actual_port
if _g_singletonic_data_plane is not None:
raise Exception("Cannot set actual host and port after data plane is created")
_g_actual_host = host
_g_actual_port = port
def set_data_plane(data_plane):
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
_g_singletonic_data_plane_connection_count = 1
_g_singletonic_data_plane = data_plane
class RemoteConnector:
"""Handle connection to both request and data planes."""
def __init__(
self,
nats_url: str,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize RemoteConnector.
Args:
nats_url (str): URL of NATS server.
"""
global _g_singletonic_data_plane
global _g_actual_port
global _g_actual_host
self._nats_url = nats_url
self._request_plane = NatsRequestPlane(nats_url)
if _g_singletonic_data_plane is None:
if _g_actual_host is not None:
data_plane_host = _g_actual_host
if _g_actual_port is not None:
data_plane_port = _g_actual_port
_g_singletonic_data_plane = UcpDataPlane(
hostname=data_plane_host,
port=data_plane_port,
keep_endpoints_open=keep_dataplane_endpoints_open,
)
self._connected = False
self._data_plane = _g_singletonic_data_plane
async def connect(self):
"""Connect to both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane is not None
await self._request_plane.connect()
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.connect()
_g_singletonic_data_plane_connection_count += 1
self._connected = True
async def close(self):
"""Disconnect from both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane is not None
await self._request_plane.close()
_g_singletonic_data_plane_connection_count -= 1
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.close()
_g_singletonic_data_plane = None
self._data_plane.close()
self._connected = False
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import typing
from typing import Optional
import numpy as np
from llm.api_server.connector import (
BaseTriton3Connector,
InferenceRequest,
InferenceResponse,
)
from llm.api_server.remote_connector import RemoteConnector
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.remote_tensor import RemoteTensor
class RemoteModelConnector(BaseTriton3Connector):
"""Connector for Triton 3 model."""
def __init__(
self,
nats_url: str,
model_name: str,
model_version: Optional[str] = None,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize Triton 3 connector.
Args:
nats_url: NATS URL (e.g. "localhost:4222").
model_name: Model name.
model_version: Model version. Default is "1".
data_plane_host: Data plane host (e.g. "localhost").
data_plane_port: Data plane port (e.g. 8001). You can use 0 to let the system choose a port.
keep_dataplane_endpoints_open: Keep data plane endpoints open to avoid reconnecting. Default is False.
Example:
remote_model_connector = RemoteModelConnector(
nats_url="localhost:4222",
data_plane_host="localhost",
data_plane_port=8001,
model_name="model_name",
model_version="1",
)
async with remote_model_connector:
request = InferenceRequest(inputs={"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
async for response in remote_model_connector.inference(model_name="model_name", request=request):
print(response.outputs)
"""
self._connector = RemoteConnector(
nats_url,
data_plane_host,
data_plane_port,
keep_dataplane_endpoints_open=keep_dataplane_endpoints_open,
)
self._model_name = model_name
if model_version is None:
model_version = "1"
self._model_version = model_version
async def connect(self):
"""Connect to Triton 3 server."""
await self._connector.connect()
self._model = RemoteOperator(
operator=(self._model_name, self._model_version),
request_plane=self._connector._request_plane,
data_plane=self._connector._data_plane,
)
async def close(self):
"""Disconnect from Triton 3 server."""
await self._connector.close()
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
async def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Yields:
Inference responses.
Raises:
TritonInferenceError: error occurred during inference.
"""
if not self._connector._connected:
await self.connect()
else:
if self._model_name != model_name:
self._model_name = model_name
self._model = RemoteOperator(
self._model_name,
self._connector._request_plane,
self._connector._data_plane,
)
results = []
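# Dict-valued parameters are flattened to strings with a "JSON:" prefix
# (presumably decoded back into dicts by the worker on the other side).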
for key, value in request.parameters.items():
if isinstance(value, dict):
request.parameters[key] = "JSON:" + json.dumps(value)
results.append(
self._model.async_infer(
inputs=request.inputs,
parameters=request.parameters,
)
)
for result in asyncio.as_completed(results):
responses = await result
async for response in responses:
triton_response = response.to_model_infer_response(
self._connector._data_plane
)
outputs = {}
for output in triton_response.outputs:
remote_tensor = RemoteTensor(output, self._connector._data_plane)
try:
local_tensor = remote_tensor.local_tensor
numpy_tensor = np.from_dlpack(local_tensor)
finally:
# FIXME: This is a workaround for the issue that the remote tensor
# is released after connection is closed.
remote_tensor.__del__()
outputs[output.name] = numpy_tensor
infer_response = InferenceResponse(
outputs=outputs,
final=response.final,
parameters=response.parameters,
)
yield infer_response
async def list_models(self) -> typing.List[str]:
"""List models available in Triton 3 system.
Returns:
List of model names.
"""
# FIXME: Support multiple models
return [self._model_name]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
pytestmark = pytest.mark.pre_merge
def test_imports():
from engine.engine import LLMEngine as e
from frontend.fastapi_frontend import FastApiFrontend as f
# Placeholder to avoid unused import errors or removal by linters
assert e, f
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import AsyncIterator
from engine.engine import LLMEngine
from llm.api_server.chat_vllm import ChatHandlerVllm
from llm.api_server.remote_model_connector import RemoteModelConnector
from schemas.openai import (
CreateChatCompletionRequest,
CreateChatCompletionResponse,
CreateCompletionRequest,
CreateCompletionResponse,
Model,
ObjectType,
)
class TritonDistributedChatHandler(ChatHandlerVllm):
def __init__(
self, triton_connector: RemoteModelConnector, model_name: str, tokenizer: str
):
super().__init__(triton_connector, model_name, tokenizer)
# Request / response format can vary between frontends, so allow override
# of adaptor functions accordingly.
def stream_response_adaptor(self, response_stream):
async def adaptor_stream():
async for response in response_stream():
if isinstance(response, Exception):
raise response
else:
# Already in SSE String format
yield response
return adaptor_stream
def response_adaptor(self, response):
return response
def exception_adaptor(self, exception):
raise exception
class TritonDistributedEngine(LLMEngine):
def __init__(
self,
nats_url: str,
data_plane_host: str,
data_plane_port: int,
model_name: str,
tokenizer: str,
):
self.triton_connector = RemoteModelConnector(
nats_url=nats_url,
data_plane_host=data_plane_host,
data_plane_port=data_plane_port,
model_name=model_name,
keep_dataplane_endpoints_open=True,
)
# FIXME: Consider supporting multiple or per-model tokenizers
self.request_handler = TritonDistributedChatHandler(
self.triton_connector, model_name, tokenizer
)
async def chat(
self, request: CreateChatCompletionRequest
) -> CreateChatCompletionResponse | AsyncIterator[str]:
"""
If request.stream is True, this returns an AsyncIterator (or Generator) that
produces server-sent-event (SSE) strings in the following form:
'data: {CreateChatCompletionStreamResponse}\n\n'
...
'data: [DONE]\n\n'
If request.stream is False, this returns a CreateChatCompletionResponse.
"""
# FIXME: Unify call whether streaming or not
if request.stream:
response_generator = await self.request_handler.process_request(
request, None
)
return response_generator()
response = await self.request_handler.process_request(request, None)
return response
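# Example (hedged sketch) of consuming the two return shapes of chat():
#   response = await engine.chat(request)
#   if request.stream:
#       async for sse_chunk in response:  # 'data: {...}\n\n' strings
#           ...
#   else:
#       ...  # response is a CreateChatCompletionResponse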
async def completion(
self, request: CreateCompletionRequest
) -> CreateCompletionResponse | AsyncIterator[str]:
"""
If request.stream is True, this returns an AsyncIterator (or Generator) that
produces server-sent-event (SSE) strings in the following form:
'data: {CreateCompletionResponse}\n\n'
...
'data: [DONE]\n\n'
If request.stream is False, this returns a CreateCompletionResponse.
"""
raise NotImplementedError
def ready(self) -> bool:
"""
Returns True if the engine is ready to accept inference requests, or False otherwise.
"""
# FIXME: Add more useful checks if available.
return True
def metrics(self) -> str:
"""
Returns the engine's metrics in a Prometheus-compatible string format.
"""
raise NotImplementedError
def models(self) -> list[Model]:
"""
Returns a List of OpenAI Model objects.
"""
# FIXME: Support 'async def models()'
model_names = asyncio.run(self.triton_connector.list_models())
models = [
Model(
id=model_name,
object=ObjectType.model,
owned_by="Triton Distributed",
# FIXME: Need to track creation times, so set to 0 for now.
created=0,
)
for model_name in model_names
]
return models
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Union
import numpy as np
from fastapi import Header, HTTPException
# Utility function to convert response to JSON
def tensor_to_json(tensor: np.ndarray) -> Any:
"""Convert numpy tensor to JSON."""
if tensor.dtype.type is np.bytes_:
items = [item.decode("utf-8") for item in tensor.flat]
if len(items) == 1:
try:
json_object = json.loads(items[0])
return json_object
except Exception:
return items[0]
return items
return tensor.tolist()
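# Example: tensor_to_json(np.array([b'{"a": 1}'])) -> {"a": 1};
# a single non-JSON bytes item decodes to a plain str, and any other
# dtype falls through to tensor.tolist().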
def json_to_tensor(json_list: str) -> np.ndarray:
"""Convert JSON to numpy tensor."""
return np.char.encode(json_list, "utf-8")
def verify_headers(content_type: Union[str, None] = Header(None)):
"""Verify content type."""
if content_type != "application/json":
raise HTTPException(
status_code=415,
detail="Unsupported media type: {content_type}. It must be application/json",
)