Commit f1f29171 authored by Neelay Shah, committed by GitHub

feat: initial worker

parent b0195f54
@@ -50,3 +50,5 @@ tensorrtllm_checkpoints/
tensorrtllm_engines/
api_server_models/
server/
+**/*backups*
\ No newline at end of file
@@ -77,6 +77,7 @@ repos:
rev: v0.5.2
hooks:
- id: ruff
+  args: [--fix, --exit-non-zero-on-fix]
# NOTE: pyright may be able to find other classes of errors not covered above,
# but would require some configuring and venv setup to properly eliminate noise
...
@@ -118,7 +118,7 @@ COPY . /workspace
RUN /workspace/icp/protos/gen_python.sh
# Sets pythonpath for python modules
-ENV PYTHONPATH="${PYTHONPATH}:/workspace/icp/src/python"
+ENV PYTHONPATH="${PYTHONPATH}:/workspace/icp/src/python:/workspace/worker/src/python"
# Command and Entrypoint
CMD []
...
@@ -46,7 +46,12 @@ skip = ["build"]
[tool.pytest.ini_options]
minversion = "8.0"
-addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "--mypy"]
+# NOTE
+# We ignore model.py explicitly here to avoid mypy errors with duplicate modules.
+# pytest overrides the default mypy exclude configuration, so we exclude here as well.
+addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "--mypy", "--ignore-glob=*model.py"]
xfail_strict = true
log_cli_level = "INFO"
filterwarnings = [
@@ -67,12 +72,19 @@ line-length = 88
indent-width = 4
[tool.mypy]
# --disable-error-code: WAR large set of errors due to mypy not being run
# previously. We can slowly enable sets of errors to fix over time.
-#disable_error_code = "arg-type,assignment,attr-defined,call-arg,call-overload,has-type,import-untyped,misc,operator,override,union-attr,var-annotated"
+# disable_error_code = []
# --explicit-package-bases: WAR errors about duplicate module names used
# throughout project such as launch_workers.py
-explicit_package_bases = true
+# explicit_package_bases = true
+# NOTE
+# We ignore model.py explicitly here to avoid mypy errors with duplicate modules.
+exclude = ["model.py"]
# --ignore-missing-imports: WAR too many errors when developing outside
# of container environment with PYTHONPATH set and packages installed.
# NOTE: Can possibly move mypy from pre-commit to a github action run only in
...
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed.worker.deployment import Deployment as Deployment
from triton_distributed.worker.operator import Operator as Operator
from triton_distributed.worker.operator import OperatorConfig as OperatorConfig
from triton_distributed.worker.remote_operator import RemoteOperator as RemoteOperator
from triton_distributed.worker.remote_request import (
RemoteInferenceRequest as RemoteInferenceRequest,
)
from triton_distributed.worker.remote_response import (
RemoteInferenceResponse as RemoteInferenceResponse,
)
from triton_distributed.worker.worker import Worker as Worker
from triton_distributed.worker.worker import WorkerConfig as WorkerConfig
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from triton_distributed.worker.worker import Worker, WorkerConfig
class Deployment:
def __init__(self, worker_configs: list[WorkerConfig]):
self._process_context = multiprocessing.get_context("spawn")
self._worker_configs = worker_configs
self._workers: list[multiprocessing.context.SpawnProcess] = []
@staticmethod
def _start_worker(worker_config):
Worker(worker_config).start()
def start(self):
for worker_config in self._worker_configs:
process = self._process_context.Process(
target=Deployment._start_worker,
name=worker_config.name,
args=[worker_config],
)
self._workers.append(process)
# Launch the spawned worker process so the deployment actually runs.
process.start()
def shutdown(self, join=True, timeout=10):
for worker in self._workers:
worker.terminate()
if join:
for worker in self._workers:
worker.join(timeout)
for worker in self._workers:
if worker.is_alive():
worker.kill()
worker.join(timeout)
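For orientation, here is a minimal, hedged usage sketch of the Deployment class above; the worker names are hypothetical and the WorkerConfig fields (operators, planes, etc.) are assumed to be filled in elsewhere, as defined later in worker.py:

# Hypothetical sketch: run two workers in spawned processes, then shut them down.
from triton_distributed.worker import Deployment, WorkerConfig

encoder_config = WorkerConfig(name="encoder_worker")   # operators etc. omitted here
decoder_config = WorkerConfig(name="decoder_worker")

deployment = Deployment([encoder_config, decoder_config])
deployment.start()
try:
    ...  # drive traffic through the request plane, e.g. via RemoteOperator
finally:
    deployment.shutdown(join=True, timeout=10)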
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
LOGGER_NAME = "Triton Worker"
class LogFormatter(logging.Formatter):
"""Class to handle formatting of the logger outputs"""
def __init__(self, logger_name=LOGGER_NAME):
logger = logging.getLogger(logger_name)
self._log_level = logger.getEffectiveLevel()
self._logger_name = logger_name
super().__init__(datefmt="%H:%M:%S")
def format(self, record):
front = "%(asctime)s %(filename)s:%(lineno)s"
self._style._fmt = f"{front}[{self._logger_name}] %(levelname)s: %(message)s"
return super().format(record)
def setup_logger(log_level=1, logger_name=LOGGER_NAME):
if log_level == 0:
log_level = logging.ERROR
elif log_level == 1:
log_level = logging.INFO
else:
log_level = logging.DEBUG
logger = logging.getLogger(logger_name)
logger.setLevel(level=log_level)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(LogFormatter(logger_name=logger_name))
logger.addHandler(handler)
logger.propagate = True
return logger
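A small usage sketch of setup_logger, following the level mapping above (0 maps to ERROR, 1 to INFO, anything greater to DEBUG); the second logger name is hypothetical:

# Sketch: configure the default worker logger at INFO verbosity.
from triton_distributed.worker.log_formatter import setup_logger

logger = setup_logger(log_level=1)          # 1 -> logging.INFO
logger.info("worker starting")              # formatted by LogFormatter
verbose = setup_logger(log_level=2, logger_name="My Component")  # >1 -> DEBUG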
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.parser import Parser
from triton_distributed.worker.worker import Worker
def main(args=None):
args, cli_parser = Parser.parse_args(args)
# TODO: Revisit the workflow args to simplify them.
worker = Worker(
request_plane=NatsRequestPlane(args.request_plane_uri),
data_plane=UcpDataPlane(),
log_level=args.log_level,
operators=args.operator_configs,
metrics_port=args.metrics_port,
log_dir=args.log_dir,
name=args.name,
triton_log_path=args.triton_log_path,
)
worker.start()
if __name__ == "__main__":
sys.exit(main())
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for Operators"""
import abc
from dataclasses import dataclass, field
from typing import Any, Optional, Type
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.remote_request import RemoteInferenceRequest
from tritonserver import Server
class Operator(abc.ABC):
@abc.abstractmethod
def __init__(
self,
name: str,
version: int,
triton_core: Server,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
):
pass
@abc.abstractmethod
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
pass
@dataclass
class OperatorConfig:
"""
Holds the properties of a hosted operator
"""
name: str
implementation: str | Type[Operator]
repository: Optional[str] = None
version: int = 1
max_inflight_requests: int = 5
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
)
log_level: Optional[int] = None
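To make the abstract interface above concrete, here is a hedged sketch of a minimal Operator implementation together with the OperatorConfig that would point at it; EchoOperator and its module path are hypothetical and not part of this commit:

# Hypothetical operator that echoes request inputs back as outputs.
from triton_distributed.worker.operator import Operator, OperatorConfig
from triton_distributed.worker.remote_request import RemoteInferenceRequest

class EchoOperator(Operator):
    def __init__(self, name, version, triton_core, request_plane, data_plane,
                 parameters, repository, logger):
        self._logger = logger

    async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
        for request in requests:
            # Resolve each remote input locally and send it back unchanged.
            outputs = {name: tensor.local_tensor for name, tensor in request.inputs.items()}
            await request.response_sender().send(outputs=outputs, final=True)

# Config a Worker could use to load it, either by class or by "module:Class" string.
echo_config = OperatorConfig(
    name="echo",
    implementation=EchoOperator,      # or "my_ops.echo:EchoOperator"
    max_inflight_requests=2,
)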
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from triton_distributed.worker.worker import OperatorConfig
# Default values
DEFAULT_REQUEST_PLANE_URI = "nats://localhost:4222"
DEFAULT_LOG_LEVEL = 0
# Property keys
NAME = "name"
VERSION = "version"
MAX_INFLIGHT_REQUESTS = "max_inflight_requests"
PARAMETERS = "parameters"
MODULE = "module"
REPOSITORY = "repository"
IMPLEMENTATION = "implementation"
class InvalidArgumentError(Exception):
pass
def _parse_name_and_properties(args, valid_properties):
kind = "operator"
args_dict = {}
if len(args) == 1:
args_dict[NAME] = args[0]
else:
for arg in args:
values = arg.split(":")
if values[0] not in valid_properties:
raise InvalidArgumentError(
f"Unexpected property found for `--{kind}` found. Expected one of {valid_properties}, found {values[0]}"
)
args_dict[values[0]] = ":".join(values[1:])
if values[0] == PARAMETERS:
parameter_file_path = args_dict[values[0]]
if not os.path.exists(parameter_file_path):
args_dict[values[0]] = json.loads(args_dict[values[0]])
else:
with open(parameter_file_path, "r") as f:
args_dict[values[0]] = json.load(f)
if NAME not in args_dict.keys():
raise InvalidArgumentError(
f"`name` is a required property for `--{kind}`. Missing `name:<{kind}_name>`, in {args}"
)
return args_dict
def _validate_operator_args(operator_args):
valid_properties = [
NAME,
VERSION,
MAX_INFLIGHT_REQUESTS,
MODULE,
PARAMETERS,
REPOSITORY,
]
properties = _parse_name_and_properties(operator_args, valid_properties)
for int_property in [VERSION, MAX_INFLIGHT_REQUESTS]:
if int_property in properties.keys():
try:
int(properties[int_property])
except ValueError:
raise InvalidArgumentError(
f"Unexpected value provided for `{int_property}` for operator `{properties[NAME]}`. Expected an integer, Got {properties[int_property]}"
)
if MODULE not in properties.keys():
raise InvalidArgumentError(
f"{MODULE} property not provided for operator `{properties[NAME]}`. This is a required property."
)
properties[IMPLEMENTATION] = properties[MODULE]
properties.pop(MODULE)
return properties
class Parser:
@classmethod
def _validate_args(cls, args):
operator_configs: list[OperatorConfig] = []
for operator_args in args.operators:
operator_properties = _validate_operator_args(operator_args)
operator_configs.append(OperatorConfig(**operator_properties))
args.operator_configs = operator_configs
# TODO: Add validation for request plane URI
@classmethod
def parse_args(cls, args=None):
parser = argparse.ArgumentParser(description="Triton Worker Component")
parser.add_argument(
"-c",
"--request-plane-uri",
type=str,
default=DEFAULT_REQUEST_PLANE_URI,
help="Request plane URI for the worker",
)
parser.add_argument(
"-l",
"--log-level",
type=int,
default=DEFAULT_LOG_LEVEL,
help="The logging level for Triton. The verbose logging can be enabled by specifying a value >= 1.",
)
parser.add_argument(
"--log-dir",
type=str,
default=None,
help="log dir folder",
)
parser.add_argument(
"--triton-log-path",
type=str,
default=None,
help="triton log path",
)
parser.add_argument(
"--name",
type=str,
default=None,
help="worker name",
)
parser.add_argument(
"-op",
"--operator",
type=str,
action="append",
nargs="+",
default=[],
dest="operators",
help="The operator to be hosted in the worker. The option can accept a single argument for the model name to load. Alternatively, it can also accept optional arguments in format `name:<model_name> version:<model_version>(optional) batch_size:<batch_size>(optional)`",
)
parser.add_argument(
"--metrics-port",
type=int,
default=0,
help="enable prometheus metrics for worker",
)
"""
TODO: Add more options as per requirements
"""
args = parser.parse_args(args)
try:
Parser._validate_args(args)
except Exception as err:
parser.error(f"Failed to validate arguments {err=}, {type(err)=}")
return args, cls
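For illustration, a hedged example of exercising the parser above programmatically; the operator name and module path are hypothetical, and the property format follows the keys validated in _validate_operator_args (name and module are required; version, max_inflight_requests, parameters, and repository are optional):

# Sketch: parse worker arguments as the CLI would, using hypothetical names.
from triton_distributed.worker.parser import Parser

args, _ = Parser.parse_args([
    "--request-plane-uri", "nats://localhost:4222",
    "--log-level", "1",
    "--operator",
    "name:echo",
    "module:my_ops.echo:EchoOperator",   # hypothetical module path and class
    "version:1",
    "max_inflight_requests:2",
])
print(args.operator_configs[0].name)     # "echo"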
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for interacting with Triton Inference Server Models"""
import asyncio
import uuid
from typing import Optional
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.remote_request import RemoteInferenceRequest
from triton_distributed.worker.remote_response import AsyncRemoteResponseIterator
from tritonserver import InvalidArgumentError
class RemoteOperator:
def __init__(
self,
name: str,
version: int,
request_plane: RequestPlane,
data_plane: DataPlane,
component_id: Optional[uuid.UUID] = None,
):
self.name = name
self.version = version
self._request_plane = request_plane
self._data_plane = data_plane
self.component_id = component_id
@property
def data_plane(self):
return self._data_plane
def create_request(self, **kwargs) -> RemoteInferenceRequest:
if "model_name" in kwargs:
kwargs.pop("model_name")
if "model_version" in kwargs:
kwargs.pop("model_version")
if "data_plane" in kwargs:
kwargs.pop("data_plane")
if "_request_plane" in kwargs:
kwargs.pop("_request_plane")
if "_model_infer_request" in kwargs:
kwargs.pop("_model_infer_request")
return RemoteInferenceRequest(
model_name=self.name,
model_version=self.version,
data_plane=self._data_plane,
_request_plane=None,
_model_infer_request=None,
**kwargs,
)
async def async_infer(
self,
inference_request: Optional[RemoteInferenceRequest] = None,
raise_on_error: bool = True,
**kwargs,
) -> AsyncRemoteResponseIterator:
if inference_request is None:
inference_request = RemoteInferenceRequest(
model_name=self.name,
model_version=self.version,
data_plane=self.data_plane,
_request_plane=None,
_model_infer_request=None,
**kwargs,
)
else:
inference_request.model_name = self.name
inference_request.model_version = self.version
if inference_request.data_plane != self.data_plane:
raise InvalidArgumentError(
"Data plane mismatch between remote request and remote operator: \n\n Operator: {self.data_plane} \n\n Request: {inference_request.data_plane}"
)
if (inference_request.response_queue is not None) and (
not isinstance(inference_request.response_queue, asyncio.Queue)
):
raise InvalidArgumentError(
"asyncio.Queue must be used for async response iterator"
)
response_iterator = AsyncRemoteResponseIterator(
self._data_plane,
inference_request,
inference_request.response_queue,
raise_on_error,
)
remote_inference_request = inference_request.to_model_infer_request()
await self._request_plane.post_request(
remote_inference_request,
response_handler=response_iterator._response_handler,
component_id=self.component_id,
)
return response_iterator
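A hedged client-side sketch of RemoteOperator.async_infer; the operator name, input name, and endpoint are hypothetical, and the plane setup mirrors what the worker entry point above does:

# Sketch: issue a request to a remote operator and iterate its responses.
import asyncio
import numpy

from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker import RemoteOperator

async def run():
    request_plane = NatsRequestPlane("nats://localhost:4222")  # hypothetical URI
    data_plane = UcpDataPlane()
    await request_plane.connect()
    data_plane.connect()

    echo = RemoteOperator("echo", 1, request_plane, data_plane)
    responses = await echo.async_infer(
        inputs={"text_input": numpy.array([b"hello"])}  # hypothetical input name
    )
    async for response in responses:
        print(response.outputs)

asyncio.run(run())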
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for sending inference requests to Triton Inference Server Models"""
from __future__ import annotations
import asyncio
import queue
import uuid
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Optional
import tritonserver
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from triton_distributed.icp.request_plane import RequestPlane, get_icp_component_id
from triton_distributed.worker.remote_response import RemoteInferenceResponse
from triton_distributed.worker.remote_tensor import RemoteTensor
from tritonserver import InferenceRequest, InvalidArgumentError, Tensor
@dataclass
class RemoteInferenceRequest:
model_name: str
model_version: int
data_plane: DataPlane
component_id: Optional[uuid.UUID] = None
_request_plane: Optional[RequestPlane] = None
_model_infer_request: Optional[ModelInferRequest] = None
request_id: Optional[str] = None
correlation_id: Optional[int | str] = None
priority: Optional[int] = None
timeout: Optional[int] = None
inputs: dict[str, RemoteTensor | Any] = field(default_factory=dict)
store_inputs_in_request: set[str] = field(default_factory=set)
parameters: dict[str, str | int | bool | float] = field(default_factory=dict)
response_queue: Optional[queue.SimpleQueue | asyncio.Queue] = None
def _set_local_request_inputs(self, local_request: tritonserver.InferenceRequest):
for input_name, remote_tensor in self.inputs.items():
local_request.inputs[input_name] = remote_tensor.local_tensor
def _set_local_request_parameters(
self, local_request: tritonserver.InferenceRequest
):
for parameter_name, parameter_value in self.parameters.items():
local_request.parameters[parameter_name] = parameter_value
def _set_model_infer_request_inputs(
self,
remote_request: ModelInferRequest,
):
for name, value in self.inputs.items():
if not isinstance(value, RemoteTensor):
if not isinstance(value, Tensor):
tensor = Tensor._from_object(value)
else:
tensor = value
use_tensor_contents = name in self.store_inputs_in_request
remote_input = self.data_plane.put_input_tensor(
tensor, use_tensor_contents=use_tensor_contents
)
else:
remote_input = self.data_plane.create_input_tensor_reference(
value.remote_tensor
)
remote_input.name = name
remote_request.inputs.append(remote_input)
def _set_model_infer_request_parameters(self, remote_request: ModelInferRequest):
for key, value in self.parameters.items():
remote_value = remote_request.parameters[key]
if isinstance(value, str):
remote_value.string_param = value
elif isinstance(value, int):
remote_value.int64_param = value
elif isinstance(value, float):
remote_value.double_param = value
elif isinstance(value, bool):
remote_value.bool_param = value
else:
raise ValueError(f"Invalid parameter type: {type(value)}")
@staticmethod
def _set_parameters_from_model_infer_request(
result: RemoteInferenceRequest,
inference_request: ModelInferRequest,
):
for name, value in inference_request.parameters.items():
if value.HasField("bool_param"):
result.parameters[name] = value.bool_param
elif value.HasField("int64_param"):
result.parameters[name] = value.int64_param
elif value.HasField("double_param"):
result.parameters[name] = value.double_param
elif value.HasField("string_param"):
result.parameters[name] = value.string_param
@staticmethod
def _set_inputs_from_model_infer_request(
result: RemoteInferenceRequest,
inference_request: ModelInferRequest,
):
for remote_input in inference_request.inputs:
result.inputs[remote_input.name] = RemoteTensor(
remote_input, result.data_plane
)
def cancel(self):
raise NotImplementedError("Cancel not implemented")
def response_sender(self):
if self._request_plane is None or self._model_infer_request is None:
raise InvalidArgumentError(
"Response only valid for requests received from request plane"
)
return RemoteResponseSender(
self._model_infer_request, self._request_plane, self.data_plane
)
@staticmethod
def from_model_infer_request(
request: ModelInferRequest, data_plane: DataPlane, request_plane: RequestPlane
) -> RemoteInferenceRequest:
result = RemoteInferenceRequest(
request.model_name,
int(request.model_version),
data_plane,
_request_plane=request_plane,
_model_infer_request=request,
)
if request.id is not None:
result.request_id = request.id
result.component_id = get_icp_component_id(request)
if "sequence_id" in request.parameters:
if request.parameters["sequence_id"].HasField("string_param"):
result.correlation_id = request.parameters["sequence_id"].string_param
else:
result.correlation_id = request.parameters["sequence_id"].int64_param
if "priority" in request.parameters:
result.priority = request.parameters["priority"].uint64_param
if "timeout" in request.parameters:
result.timeout = request.parameters["timeout"].uint64_param
RemoteInferenceRequest._set_inputs_from_model_infer_request(result, request)
RemoteInferenceRequest._set_parameters_from_model_infer_request(result, request)
return result
def to_local_request(self, model: tritonserver.Model) -> InferenceRequest:
local_request = model.create_request()
if self.request_id is not None:
local_request.request_id = self.request_id
if self.priority is not None:
local_request.priority = self.priority
if self.timeout is not None:
local_request.timeout = self.timeout
if self.correlation_id is not None:
local_request.correlation_id = self.correlation_id
self._set_local_request_inputs(local_request)
self._set_local_request_parameters(local_request)
return local_request
def to_model_infer_request(self) -> ModelInferRequest:
remote_request = ModelInferRequest()
remote_request.model_name = self.model_name
remote_request.model_version = str(self.model_version)
if self.request_id is not None:
remote_request.id = self.request_id
if self.priority is not None:
remote_request.parameters["priority"].uint64_param = self.priority
if self.timeout is not None:
remote_request.parameters["timeout"].uint64_param = self.timeout
if self.correlation_id is not None:
if isinstance(self.correlation_id, str):
remote_request.parameters[
"sequence_id"
].string_param = self.correlation_id
else:
remote_request.parameters[
"sequence_id"
].int64_param = self.correlation_id
self._set_model_infer_request_inputs(remote_request)
self._set_model_infer_request_parameters(remote_request)
return remote_request
class RemoteResponseSender:
response_counts: Counter = Counter()
def __init__(
self,
model_infer_request: ModelInferRequest,
request_plane: RequestPlane,
data_plane: DataPlane,
):
self._model_infer_request = model_infer_request
self._request_plane = request_plane
self._data_plane = data_plane
def create_response(self, **kwargs) -> RemoteInferenceResponse:
if "model_name" in kwargs:
kwargs.pop("model_name")
if "model_version" in kwargs:
kwargs.pop("model_version")
if "request_id" in kwargs:
kwargs.pop("request_id")
return RemoteInferenceResponse(
model_name=self._model_infer_request.model_name,
model_version=self._model_infer_request.model_version,
request_id=self._model_infer_request.id,
**kwargs,
)
async def send(
self, inference_response: Optional[RemoteInferenceResponse] = None, **kwargs
) -> None:
if inference_response is None:
inference_response = RemoteInferenceResponse(
model_name=self._model_infer_request.model_name,
model_version=self._model_infer_request.model_version,
request_id=self._model_infer_request.id,
**kwargs,
)
await self._request_plane.post_response(
self._model_infer_request,
inference_response.to_model_infer_response(self._data_plane),
)
if inference_response.final:
RemoteResponseSender.response_counts[
(
self._model_infer_request.model_name,
self._model_infer_request.model_version,
)
] += 1
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for receiving inference responses to Triton Inference Server Models"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferResponse
if TYPE_CHECKING:
from triton_distributed.worker.remote_request import RemoteInferenceRequest
import uuid
from triton_distributed.icp.request_plane import (
get_icp_component_id,
get_icp_final_response,
get_icp_response_error,
set_icp_final_response,
set_icp_response_error,
)
from triton_distributed.worker.remote_tensor import RemoteTensor
from tritonserver import InternalError, Tensor, TritonError
from tritonserver._api._response import InferenceResponse
class AsyncRemoteResponseIterator:
"""Asyncio compatible response iterator
Response iterators are returned from model inference methods and
allow users to process inference responses in the order they were
received for a request.
"""
def __init__(
self,
data_plane: DataPlane,
request: RemoteInferenceRequest,
user_queue: Optional[asyncio.Queue] = None,
raise_on_error: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Initialize AsyncResponseIterator
AsyncResponseIterator objects are obtained from Model inference
methods and not instantiated directly. See `Model` documentation.
Parameters
----------
model : Model
model associated with inference request
request : TRITONSERVER_InferenceRequest
Underlying C binding TRITONSERVER_InferenceRequest
object. Private.
user_queue : Optional[asyncio.Queue]
Optional user queue for responses in addition to internal
iterator queue.
raise_on_error : bool
if True response errors will be raised as exceptions.
loop : Optional[asyncio.AbstractEventLoop]
asyncio loop object
"""
if loop is None:
loop = asyncio.get_running_loop()
self._loop = loop
self._queue: asyncio.Queue = asyncio.Queue()
self._user_queue = user_queue
self._complete = False
self._request = request
self._data_plane = data_plane
self._raise_on_error = raise_on_error
def __aiter__(self) -> AsyncRemoteResponseIterator:
"""Return async iterator. For use with async for loops.
Returns
-------
AsyncRemoteResponseIterator
Examples
--------
responses = await remote_operator.async_infer(inputs={"fp16_input": numpy.array([[1]], dtype=numpy.float16)})
async for response in responses:
print(numpy.from_dlpack(response.outputs["fp16_output"]))
"""
return self
async def __anext__(self):
"""Returns the next response received for a request
Returns the next response received for a request as an
awaitable object.
Raises
------
response.error
If raise_on_error is set to True, response errors are
raised as exceptions
StopAsyncIteration
Indicates all responses for a request have been received.
Final responses may or may not include outputs and must be
checked.
"""
if self._complete:
raise StopAsyncIteration
response = await self._queue.get()
self._complete = response.final
if response.error is not None and self._raise_on_error:
raise response.error
return response
def cancel(self) -> None:
"""Cancel an inflight request
Cancels an in-flight request. Cancellation is handled on a
best effort basis and may not prevent execution of a request
if it is already started or completed.
Note: the underlying RemoteInferenceRequest.cancel is not yet
implemented, so calling this currently raises NotImplementedError.
Examples
--------
responses = await remote_operator.async_infer(inputs={"text_input": ["hello"]})
responses.cancel()
"""
if self._request is not None:
self._request.cancel()
def _response_handler(self, response: ModelInferResponse):
try:
if self._request is None:
raise InternalError("Response received after final response flag")
final = False
if response is None or get_icp_final_response(response):
final = True
remote_response = RemoteInferenceResponse.from_model_infer_response(
self._request, response, self._data_plane, final
)
asyncio.run_coroutine_threadsafe(
self._queue.put(remote_response), self._loop
)
if self._user_queue is not None:
asyncio.run_coroutine_threadsafe(
self._user_queue.put(remote_response), self._loop
)
if final:
del self._request
self._request = None
except Exception as e:
message = f"Catastrophic failure in response callback: {e}"
print(message)
# catastrophic failure
raise e from None
@dataclass
class RemoteInferenceResponse:
"""Dataclass representing an inference response.
Inference response objects are returned from response iterators
which are in turn returned from model inference methods. They
contain output data, output parameters, any potential errors
reported and a flag to indicate if the response is the final one
for a request.
See c:func:`TRITONSERVER_InferenceResponse` for the underlying response semantics.
Parameters
----------
model_name : str
Name of the model associated with the response.
model_version : int
Version of the model associated with the response.
component_id : Optional[uuid.UUID], default None
Identifier of the component that produced the response.
request_id : Optional[str], default None
Unique identifier for the inference request (if provided).
parameters : dict[str, str | int | bool], default {}
Additional parameters associated with the response.
outputs : dict[str, RemoteTensor | Tensor], default {}
Output tensors for the inference.
store_outputs_in_response : set[str], default set()
Names of outputs whose contents are embedded directly in the response.
error : Optional[TritonError], default None
Error (if any) that occurred in the processing of the request.
classification_label : Optional[str], default None
Classification label associated with the inference. Not currently supported.
final : bool, default False
Flag indicating if the response is final.
"""
model_name: str
model_version: int
component_id: Optional[uuid.UUID] = None
request_id: Optional[str] = None
parameters: dict[str, str | int | bool] = field(default_factory=dict)
outputs: dict[str, RemoteTensor | Tensor] = field(default_factory=dict)
store_outputs_in_response: set[str] = field(default_factory=set)
error: Optional[TritonError] = None
classification_label: Optional[str] = None
final: bool = False
def _set_parameters_from_model_infer_response_parameters(
self, response: ModelInferResponse
):
for name, value in response.parameters.items():
if value.HasField("string_param"):
self.parameters[name] = value.string_param
elif value.HasField("int64_param"):
self.parameters[name] = value.int64_param
elif value.HasField("double_param"):
self.parameters[name] = value.double_param
elif value.HasField("bool_param"):
self.parameters[name] = value.bool_param
def _set_model_infer_response_outputs(
self, response: ModelInferResponse, data_plane: DataPlane
):
for name, value in self.outputs.items():
if not isinstance(value, RemoteTensor):
if not isinstance(value, Tensor):
tensor = Tensor._from_object(value)
else:
tensor = value
use_tensor_contents = name in self.store_outputs_in_response
remote_output = data_plane.put_output_tensor(
tensor, use_tensor_contents=use_tensor_contents
)
else:
remote_output = data_plane.create_output_tensor_reference(
value.remote_tensor
)
remote_output.name = name
response.outputs.append(remote_output)
def _set_model_infer_response_parameters(self, response: ModelInferResponse):
for key, value in self.parameters.items():
remote_value = response.parameters[key]
if isinstance(value, str):
remote_value.string_param = value
elif isinstance(value, int):
remote_value.int64_param = value
elif isinstance(value, float):
remote_value.double_param = value
elif isinstance(value, bool):
remote_value.bool_param = value
def to_model_infer_response(self, data_plane: DataPlane):
remote_response = ModelInferResponse()
remote_response.model_name = self.model_name
remote_response.model_version = str(self.model_version)
if self.request_id:
remote_response.id = self.request_id
if self.error:
set_icp_response_error(remote_response, self.error)
if self.final:
set_icp_final_response(remote_response, self.final)
self._set_model_infer_response_parameters(remote_response)
self._set_model_infer_response_outputs(remote_response, data_plane)
return remote_response
@staticmethod
def from_local_response(
local_response: InferenceResponse, store_outputs_in_response: bool = False
):
result = RemoteInferenceResponse(
local_response.model.name,
local_response.model.version,
None,
local_response.request_id,
final=local_response.final,
)
for tensor_name, tensor_value in local_response.outputs.items():
result.outputs[tensor_name] = tensor_value
if store_outputs_in_response:
result.store_outputs_in_response.add(tensor_name)
for parameter_name, parameter_value in local_response.parameters.items():
result.parameters[parameter_name] = parameter_value
result.error = local_response.error
return result
@staticmethod
def from_model_infer_response(
request: RemoteInferenceRequest,
response: ModelInferResponse,
data_plane: DataPlane,
final_response: bool,
) -> RemoteInferenceResponse:
result = RemoteInferenceResponse(
request.model_name,
request.model_version,
None,
request.request_id,
final=final_response,
)
if response is None:
return result
result.request_id = response.id
result.component_id = get_icp_component_id(response)
outputs = {}
for output in response.outputs:
outputs[output.name] = RemoteTensor(output, data_plane)
result.outputs = outputs
result._set_parameters_from_model_infer_response_parameters(response)
result.error = get_icp_response_error(response)
return result
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence
import cupy
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.data_plane import (
DataPlane,
get_icp_data_type,
get_icp_memory_type,
get_icp_shape,
get_icp_tensor_size,
)
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from tritonserver import DataType, InvalidArgumentError, MemoryType, Tensor
# TODO
# Export from tritonserver
from tritonserver._api._dlpack import DLDeviceType
from tritonserver._api._tensor import DeviceOrMemoryType
# Run cupy's cuda.is_available once to
# avoid the exception hitting runtime code.
try:
cupy.cuda.is_available()
except CUDARuntimeError:
pass
@dataclass
class RemoteTensor:
remote_tensor: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor
data_plane: DataPlane
_local_tensor: Optional[Tensor] = None
# FIXME: This is a hack to avoid double deletion of the tensor
# Tensor must be explicitly released by the user before data plane connection is closed
deleted: bool = False
@property
def data_type(self) -> DataType | None:
return get_icp_data_type(self.remote_tensor)
@property
def shape(self) -> Sequence[int] | None:
return get_icp_shape(self.remote_tensor)
@property
def memory_type(self) -> MemoryType | None:
return get_icp_memory_type(self.remote_tensor)
@property
def size(self) -> int | None:
return get_icp_tensor_size(self.remote_tensor)
@property
def local_tensor(self) -> Tensor:
if not self._local_tensor:
self._local_tensor = self.data_plane.get_tensor(self.remote_tensor)
if self._local_tensor is None:
raise InvalidArgumentError("Not able to resolve Tensor locally")
return self._local_tensor
@property
def data_ptr(self) -> int:
return self.local_tensor.data_ptr
def __dlpack__(self, *, stream=None):
return self.local_tensor.__dlpack__(stream=stream)
def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
return self.local_tensor.__dlpack_device__()
def to_string_array(self):
return self.local_tensor.to_string_array()
def to_bytes_array(self):
return self.local_tensor.to_bytes_array()
def to_host(self) -> Tensor:
return self.local_tensor.to_host()
def to_device(self, device: DeviceOrMemoryType) -> Tensor:
return self.local_tensor.to_device(device)
def __del__(self):
# FIXME: This is a hack to avoid double deletion of the tensor
# Tensor must be explicitly released by the user before data plane connection is closed
if not self.deleted:
self.data_plane.release_tensor(self.remote_tensor)
self.deleted = True
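RemoteTensor defers pulling tensor bytes until local_tensor is accessed, and it implements __dlpack__, so it (or its local tensor) can be handed to DLPack-aware libraries. A hedged sketch, assuming a RemoteTensor named remote_tensor is already in hand, for example from response.outputs:

# Sketch: materialize a remote tensor locally and view it as a numpy array.
import numpy

local = remote_tensor.local_tensor          # fetched via the data plane on first access
host = local.to_host()                      # ensure host memory before numpy conversion
array = numpy.from_dlpack(host)             # view through DLPack
print(remote_tensor.shape, remote_tensor.data_type)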
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
import uuid
from typing import Optional
from google.protobuf import json_format, text_format
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.operator import Operator
from triton_distributed.worker.remote_request import RemoteInferenceRequest
from triton_distributed.worker.remote_response import RemoteInferenceResponse
from tritonclient.grpc import model_config_pb2
from tritonserver import InvalidArgumentError, Server
class TritonCoreOperator(Operator):
def __init__(
self,
name: str,
version: int,
triton_core: Server,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: dict,
repository: Optional[str] = None,
logger: logging.Logger = logging.getLogger(),
):
self._repository = repository
self._name = name
self._parameters = parameters
self._triton_core = triton_core
self._version = version
self._logger = logger
self._request_plane = request_plane
self._data_plane = data_plane
self._store_outputs_in_response = self._parameters.get(
"store_outputs_in_response", False
)
if not self._repository:
self._repository = "."
if repository:
triton_core.register_model_repository(repository)
parameter_config = self._parameters.get("config", None)
model_config = None
try:
model_config_path = os.path.join(
self._repository, self._name, "config.pbtxt"
)
with open(model_config_path, "r") as config_file:
model_config = text_format.Parse(
config_file.read(), model_config_pb2.ModelConfig()
)
except Exception:
pass
if parameter_config and model_config:
model_config.MergeFrom(
json_format.Parse(
json.dumps(parameter_config), model_config_pb2.ModelConfig()
)
)
model_config = {"config": json_format.MessageToJson(model_config)}
elif parameter_config:
model_config = {"config": parameter_config}
else:
model_config = None
self._local_model = self._triton_core.load(self._name, model_config)
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
request_id_map = {}
response_queue: asyncio.Queue = asyncio.Queue()
for request in requests:
self._logger.info("\n\nReceived request: \n\n%s\n\n", request)
try:
local_request = request.to_local_request(self._local_model)
except Exception as e:
message = f"Can't resolve tensors for request, ignoring request,{e}"
self._logger.error(message)
await request.response_sender().send(
error=InvalidArgumentError(message), final=True
)
continue
request_id = str(uuid.uuid1())
original_id = None
if local_request.request_id is not None:
original_id = local_request.request_id
local_request.request_id = request_id
request_id_map[request_id] = (request.response_sender(), original_id)
local_request.response_queue = response_queue
self._local_model.async_infer(local_request)
while request_id_map:
local_response = await response_queue.get()
remote_response = RemoteInferenceResponse.from_local_response(
local_response, self._store_outputs_in_response
)
response_sender, original_id = request_id_map[local_response.request_id]
remote_response.request_id = original_id
if local_response.final:
del request_id_map[local_response.request_id]
self._logger.info("\n\nSending response\n\n%s\n\n", remote_response)
await response_sender.send(remote_response)
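For illustration, a hedged sketch of an OperatorConfig that would load this TritonCoreOperator for an existing Triton model; the model name, repository path, and import path of this file are hypothetical (an optional "config" entry in parameters can carry overrides merged over the model's config.pbtxt, per __init__ above):

# Assumed module path for the file above; adjust to the actual package layout.
from triton_distributed.worker.triton_core_operator import TritonCoreOperator
from triton_distributed.worker import OperatorConfig

core_operator_config = OperatorConfig(
    name="my_model",                      # hypothetical Triton model name in the repository
    implementation=TritonCoreOperator,    # could also be given as a "module:Class" string
    repository="/workspace/models",       # hypothetical model repository registered at init
    parameters={"store_outputs_in_response": True},
)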
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import importlib
import logging
import multiprocessing
import os
import pathlib
import signal
import sys
import uuid
from collections import Counter
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, Type
import tritonserver
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker.log_formatter import LOGGER_NAME, setup_logger
from triton_distributed.worker.operator import Operator, OperatorConfig
from triton_distributed.worker.remote_request import (
RemoteInferenceRequest,
RemoteResponseSender,
)
if TYPE_CHECKING:
import uvicorn
logger = logging.getLogger(LOGGER_NAME)
@dataclass
class WorkerConfig:
request_plane: Type[RequestPlane] = NatsRequestPlane
data_plane: Type[DataPlane] = UcpDataPlane
request_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {}))
data_plane_args: tuple[list, dict] = field(default_factory=lambda: ([], {}))
log_level: int = 0
operators: list[OperatorConfig] = field(default_factory=list)
triton_log_path: Optional[str] = None
name: str = str(uuid.uuid1())
log_dir: Optional[str] = None
metrics_port: int = 0
class Worker:
def __init__(
self, config: Optional[WorkerConfig] = None, **kwargs #: Unpack[WorkerConfig]
) -> None:
if config is None:
config = WorkerConfig(**kwargs)
self._request_plane = config.request_plane(
*config.request_plane_args[0], **config.request_plane_args[1]
)
self._data_plane = config.data_plane(
*config.data_plane_args[0], **config.data_plane_args[1]
)
self._triton_log_path = config.triton_log_path
self._name = config.name
self._log_level = config.log_level
self._operator_configs = config.operators
self._log_dir = config.log_dir
self._stop_requested = False
self._requests_received: Counter = Counter()
self._background_tasks: dict[object, set] = {}
self._completion_conds: dict[object, asyncio.Condition] = {}
self._inflight_requests: dict[object, int] = {}
self._max_inflight_requests: dict[object, int] = {}
self._operators: dict[tuple[str, int], Operator] = {}
self._metrics_port = config.metrics_port
self._metrics_server: Optional[uvicorn.Server] = None
def _import_operators(self):
for operator_config in self._operator_configs:
if operator_config.repository:
repository_path = pathlib.Path(operator_config.repository)
sys.path.append(str(repository_path.absolute()))
else:
repository_path = pathlib.Path(".")
if isinstance(operator_config.implementation, str):
split_workflow = operator_config.implementation.split(":")
module_name = ":".join(split_workflow[:-1])
class_name = split_workflow[-1]
module_path = pathlib.Path(module_name)
parent_paths = list(module_path.parents)
root_parent = pathlib.Path(".")
if parent_paths:
root_parent = parent_paths[-1]
if root_parent == pathlib.Path("."):
module_path = repository_path.joinpath(module_path)
if str(module_path.parent.absolute()) not in sys.path:
sys.path.append(str(module_path.parent.absolute()))
try:
module = importlib.import_module(module_path.name)
class_ = getattr(module, class_name)
except Exception as e:
logger.exception(
"can't instantiate operator: %s %s", operator_config.name, e
)
raise e
elif issubclass(operator_config.implementation, Operator):
class_ = operator_config.implementation
else:
logger.error(
"can't instantiate operator: %s",
operator_config.name,
)
raise Exception("invalid implementation type")
try:
if operator_config.log_level is None:
operator_config.log_level = self._log_level
operator_logger = setup_logger(
log_level=operator_config.log_level,
logger_name=f"OPERATOR{(operator_config.name,operator_config.version)}",
)
operator = class_(
operator_config.name,
operator_config.version,
self._triton_core,
self._request_plane,
self._data_plane,
operator_config.parameters,
operator_config.repository,
operator_logger,
)
except Exception as e:
logger.exception(
"can't instantiate operator: %s %s", operator_config.name, e
)
raise e
operator_key = (operator_config.name, operator_config.version)
self._operators[operator_key] = operator
self._max_inflight_requests[operator] = operator_config.max_inflight_requests
self._inflight_requests[operator] = 0
self._background_tasks[operator] = set()
self._completion_conds[operator] = asyncio.Condition()
async def _process_request(self, request):
logger.info("\n\nserver received request: \n\n%s\n\n", request)
operator_key = (request.model_name, int(request.model_version))
if operator_key in self._operators:
operator = self._operators[operator_key]
self._requests_received[operator] += 1
remote_request = RemoteInferenceRequest.from_model_infer_request(
request, self._data_plane, self._request_plane
)
await operator.execute([remote_request])
else:
logger.warning("Received request for unknown operator: %s", operator_key)
async def _process_request_task(self, operator, name, version):
requests = await self._request_plane.pull_requests(name, str(version))
# When the request is received, notify the handler to
# pull next requests if capacity permits.
async with self._completion_conds[operator]:
self._inflight_requests[operator] += 1
logger.debug(f"{operator} inflight: {self._inflight_requests[operator]}")
self._completion_conds[operator].notify()
# Process request received from the request plane
async for request in requests:
await self._process_request(request)
# The request is processed and new requests may be
# pulled.
async with self._completion_conds[operator]:
self._inflight_requests[operator] -= 1
logger.debug(f"{operator} inflight {self._inflight_requests[operator]}")
self._completion_conds[operator].notify()
async def _add_process_request_task(self, operator, name, version):
task = asyncio.create_task(self._process_request_task(operator, name, version))
self._background_tasks[operator].add(task)
task.add_done_callback(self._background_tasks[operator].discard)
async def _request_handler(self, operator, name, version):
while not self._stop_requested:
async with self._completion_conds[operator]:
# TODO: Instead of pulling a fixed number of requests try
# querying the model status to understand whether or not
# to pull more requests.
if (
self._inflight_requests[operator]
< self._max_inflight_requests[operator]
):
await self._add_process_request_task(operator, name, version)
# Block the handler till task is notified
# We want to create new tasks only when they
# are needed so that at a given time, there
# is only a single model task pulling from the
# request plane.
await self._completion_conds[operator].wait()
async def _initialize_request_handlers(self):
handlers = []
for (name, version), operator in self._operators.items():
logger.info(f"Starting {name} handler...")
handlers.append(self._request_handler(operator, name, version))
await asyncio.gather(*handlers)
async def serve(self):
self._triton_core = tritonserver.Server(
model_repository=".",
log_error=True,
log_verbose=self._log_level,
strict_model_config=False,
model_control_mode=tritonserver.ModelControlMode.EXPLICIT,
log_file=self._triton_log_path,
).start(wait_until_ready=True)
try:
await self._request_plane.connect()
except Exception as e:
logger.exception(
"Encountered an error when trying to connect to request plane"
)
raise e
try:
self._data_plane.connect()
except Exception as e:
logger.exception(
"Encountered and error when trying to connect to data plane"
)
raise e
try:
self._import_operators()
logger.info("Worker started...")
await self._initialize_request_handlers()
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception("Encountered an error in worker: %s", e)
self._stop_requested = True
logger.info("worker store: %s", list(self._data_plane._tensor_store.keys()))
logger.info("Worker stopped...")
logger.info(
"Hosted Operators: %s Requests Received: %s Responses Sent: %s",
self._operators,
self._requests_received,
RemoteResponseSender.response_counts,
)
await self._request_plane.close()
self._data_plane.close()
if self._metrics_server:
self._metrics_server.should_exit = True
await self._metrics_server.shutdown()
async def shutdown(self, signal):
logger.info("Received exit signal %s...", signal.name)
self._stop_requested = True
try:
if self._data_plane:
self._data_plane.close()
except Exception as e:
logger.exception("Failed to close the data plane: %s", e)
try:
if self._request_plane:
await self._request_plane.close()
except Exception as e:
logger.exception("Failed to close the request plane: %s", e)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
logger.info("Cancelling %s outstanding tasks", len(tasks))
[task.cancel() for task in tasks]
self._triton_core.stop()
if self._metrics_server:
self._metrics_server.should_exit = True
await self._metrics_server.shutdown()
def _setup_metrics_server(self):
import uvicorn
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
app = FastAPI()
config = uvicorn.Config(app, port=self._metrics_port)
server = uvicorn.Server(config)
@app.get("/metrics", response_class=PlainTextResponse)
def metrics() -> str:
if self._triton_core:
return self._triton_core.metrics()
else:
return ""
return server
async def _wait_for_tasks(self, loop):
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
try:
await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError as e:
logger.exception("Cancelled in task clean-up: %s", e)
except Exception as e:
logger.exception("Encountered an error in task clean-up: %s", e)
logger.info("Stopping the event loop")
loop.stop()
def start(self):
if self._log_dir:
os.makedirs(self._log_dir, exist_ok=True)
stdout_path = os.path.join(self._log_dir, f"{self._name}.stdout.log")
stderr_path = os.path.join(self._log_dir, f"{self._name}.stderr.log")
if not self._triton_log_path:
self._triton_log_path = os.path.join(
self._log_dir, f"{self._name}.triton.log"
)
sys.stdout = open(stdout_path, "w", buffering=1)
sys.stderr = open(stderr_path, "w", buffering=1)
triton_log = open(self._triton_log_path, "w", buffering=1)
triton_log.close()
setup_logger(log_level=self._log_level)
loop = asyncio.get_event_loop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
# Note: mypy has known issues inferring
# types of lambdas that include capturing loop variables
for sig in signals:
loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(self.shutdown(s)) # type: ignore
)
try:
if self._metrics_port:
loop.create_task(self.serve())
self._metrics_server = self._setup_metrics_server()
assert self._metrics_server, "Unable to start metrics server"
loop.run_until_complete(self._metrics_server.serve())
else:
loop.run_until_complete(self.serve())
except asyncio.CancelledError:
pass
logger.info("Worker cancelled!")
finally:
loop.run_until_complete(self._wait_for_tasks(loop))
loop.close()
logger.info("Successfully shutdown worker.")
sys.stdout.flush()
sys.stderr.flush()
if self._log_dir:
sys.stdout.close()
sys.stderr.close()
class Deployment:
def __init__(self, worker_configs: list[WorkerConfig]):
self._process_context = multiprocessing.get_context("spawn")
self._worker_configs = worker_configs
self._workers: list[multiprocessing.context.SpawnProcess] = []
@staticmethod
def _start_worker(worker_config):
Worker(worker_config).start()
def start(self):
for worker_config in self._worker_configs:
process = self._process_context.Process(
target=Deployment._start_worker,
name=worker_config.name,
args=[worker_config],
)
self._workers.append(process)
# Launch the spawned worker process so the deployment actually runs.
process.start()
def shutdown(self, join=True, timeout=10):
for worker in self._workers:
worker.terminate()
if join:
for worker in self._workers:
worker.join(timeout)
for worker in self._workers:
if worker.is_alive():
worker.kill()
worker.join(timeout)
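Putting the pieces together, a hedged end-to-end sketch of wiring WorkerConfig and Worker; the operator module path is hypothetical and the NATS endpoint matches the default request-plane URI used elsewhere in this commit:

# Sketch: configure and run a single worker in-process with one hypothetical operator.
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.worker import OperatorConfig, Worker, WorkerConfig

echo_config = OperatorConfig(
    name="echo",
    implementation="my_ops.echo:EchoOperator",   # hypothetical module path and class
)

worker_config = WorkerConfig(
    request_plane=NatsRequestPlane,
    data_plane=UcpDataPlane,
    request_plane_args=(["nats://localhost:4222"], {}),
    operators=[echo_config],
    name="echo_worker",
    log_level=1,
)

# Blocks until SIGINT/SIGTERM/SIGHUP triggers shutdown; use Deployment for multiple workers.
Worker(worker_config).start()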
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import multiprocessing
import signal
import subprocess
import sys
import time
import pytest
import pytest_asyncio
from triton_distributed.icp.nats_request_plane import NatsServer
from triton_distributed.worker.log_formatter import LOGGER_NAME, setup_logger
from triton_distributed.worker.worker import Worker
logger = logging.getLogger(LOGGER_NAME)
NATS_PORT = 4223
TEST_API_SERVER_MODEL_REPO_PATH = (
"/workspace/worker/python/tests/integration/api_server/models"
)
async def _wait_for_tasks(loop):
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
try:
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
        print(f"Encountered an error in task clean-up: {e}")
print("Stopping the event loop")
loop.stop()
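# Test helper executed in a spawned process: redirects stdio to per-worker log
# files, reports "READY" (or a failure message) on the queue, serves the Worker
# until cancelled, and exits with the number of tensors still held in the
# worker's tensor store so tests can assert that clean-up happened.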
def _run_worker(name, queue, worker_config):
tensor_store_keys = None
try:
with open(f"{name}.worker.stdout.log", "w") as output_:
with open(f"{name}.worker.stderr.log", "w") as output_err:
with open(f"{name}.worker.triton.log", "w"):
sys.stdout = output_
sys.stderr = output_err
triton_log_filename = f"{name}.worker.triton.log"
setup_logger(log_level=worker_config.log_level)
worker_config.triton_log_file = triton_log_filename
worker_config.name = name
try:
worker = Worker(worker_config)
except Exception as e:
queue.put(f"Failed to start {name}: {e}")
                        logger.exception("Failed to instantiate a worker class")
                        return
loop = asyncio.new_event_loop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(worker.shutdown(s)) # type: ignore
)
try:
queue.put("READY")
loop.run_until_complete(worker.serve())
except asyncio.CancelledError:
print("server cancellation detected")
finally:
loop.run_until_complete(_wait_for_tasks(loop))
loop.close()
tensor_store_keys = list(
worker._data_plane._tensor_store.keys()
)
sys.exit(len(tensor_store_keys))
except Exception as e:
print(f"Worker Serving Failed to start: {e}")
queue.put(f"Failed to start {name}: {e}")
raise e
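# Spawns worker processes with a given set of operators for the integration
# tests and terminates/joins them afterwards, asserting (when check_status is
# set) that each worker exited without leaking stored tensors. Sketch of the
# intended use (operators / worker_config are placeholders from the tests):
#
#     queue = WorkerManager.ctx.Queue()
#     worker = WorkerManager.setup_worker_process(
#         operators, "worker_0", queue, worker_config
#     )
#     ...
#     WorkerManager.cleanup_workers([worker])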
class WorkerManager:
ctx = multiprocessing.get_context("spawn")
@staticmethod
def setup_worker_process(operators, name, queue, worker_config):
worker_config.name = name
worker_config.operators = operators
process = WorkerManager.ctx.Process(
target=_run_worker,
args=(name, queue, worker_config),
name=name,
)
process.start()
return process
@staticmethod
def cleanup_workers(workers, check_status=True):
for worker in workers:
print(f"Terminating {worker.name} worker", flush=True)
worker.terminate()
for worker in workers:
worker.join()
print(f"{worker.name} exited with {worker.exitcode} stored tensors")
            assert (
                not check_status or worker.exitcode == 0
            ), f"{worker.name} exited with {worker.exitcode} stored tensors"
@pytest.fixture
def worker_manager():
return WorkerManager
@pytest.fixture(scope="session")
def nats_server():
server = NatsServer()
yield server
del server
@pytest.fixture(scope="session")
def api_server():
command = ["tritonserver", "--model-store", str(TEST_API_SERVER_MODEL_REPO_PATH)]
with open("api_server.stdout.log", "wt") as output_:
with open("api_server.stderr.log", "wt") as output_err:
process = subprocess.Popen(
command, stdin=subprocess.DEVNULL, stdout=output_, stderr=output_err
)
time.sleep(5)
yield process
process.terminate()
process.wait()
print("Successfully cleaned-up T2 API server")
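# Wraps pytest-benchmark so coroutine functions can be timed: async callables
# are driven to completion on the event loop inside the benchmarked callable,
# while synchronous callables are passed straight to benchmark().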
@pytest_asyncio.fixture
async def aio_benchmark(benchmark):
async def run_async_coroutine(func, *args, **kwargs):
return await func(*args, **kwargs)
def _wrapper(func, *args, **kwargs):
if asyncio.iscoroutinefunction(func):
@benchmark
def _():
future = asyncio.ensure_future(
run_async_coroutine(func, *args, **kwargs)
)
return asyncio.get_event_loop().run_until_complete(future)
else:
benchmark(func, *args, **kwargs)
return _wrapper
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import numpy
from triton_distributed.worker import Operator, RemoteInferenceRequest, RemoteOperator
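# Example operator used by the integration tests: each request is sent to the
# remote "add" operator, its partial and total sums are then fanned out
# concurrently to the remote "multiply" and "divide" operators, and all
# collected outputs (or the first error observed) are returned in a single
# final response.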
class AddMultiplyDivide(Operator):
def __init__(
self,
name,
version,
triton_core,
request_plane,
data_plane,
parameters,
repository,
logger,
):
self._triton_core = triton_core
self._request_plane = request_plane
self._data_plane = data_plane
self._parameters = parameters
self._add_model = RemoteOperator(
"add", 1, self._request_plane, self._data_plane
)
self._multiply_model = RemoteOperator(
"multiply", 1, self._request_plane, self._data_plane
)
self._divide_model = RemoteOperator(
"divide", 1, self._request_plane, self._data_plane
)
async def execute(self, requests: list[RemoteInferenceRequest]):
print("in execute!", flush=True)
for request in requests:
outputs = {}
print(request.inputs, flush=True)
array = None
try:
array = numpy.from_dlpack(request.inputs["int64_input"])
except Exception as e:
print(e)
print(array)
response = [
response
async for response in await self._add_model.async_infer(
inputs={"int64_input": array}
)
][0]
print(response, flush=True)
for output_name, output_value in response.outputs.items():
outputs[f"{response.model_name}_{output_name}"] = output_value
addition_output_partial = response.outputs["int64_output_partial"]
addition_output_total = response.outputs["int64_output_total"]
            multiply_responses = self._multiply_model.async_infer(
inputs={"int64_input": addition_output_partial}, raise_on_error=False
)
divide_responses = self._divide_model.async_infer(
inputs={
"int64_input": addition_output_partial,
"int64_input_divisor": addition_output_total,
},
raise_on_error=False,
)
error = None
            for result in asyncio.as_completed([multiply_responses, divide_responses]):
responses = await result
async for response in responses:
print("response!", response, flush=True)
print("error!", response.error, flush=True)
if response.error is not None:
error = response.error
break
for output_name, output_value in response.outputs.items():
outputs[f"{response.model_name}_{output_name}"] = output_value
if error is not None:
await request.response_sender().send(error=error, final=True)
else:
await request.response_sender().send(outputs=outputs, final=True)
for output in outputs.values():
del output
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
from triton_distributed.worker import Operator, RemoteInferenceRequest
class Identity(Operator):
"""
    A dummy operator that echoes its single input tensor back as the output.
"""
def __init__(
self,
name,
version,
triton_core,
request_plane,
data_plane,
params,
repository,
logger,
):
self._triton_core = triton_core
self._request_plane = request_plane
self._data_plane = data_plane
self._params = params
async def execute(self, requests: list[RemoteInferenceRequest]):
for request in requests:
try:
array = numpy.from_dlpack(request.inputs["input"])
except Exception as e:
print(e)
await request.response_sender().send(final=True, error=e)
return
outputs: dict[str, numpy.ndarray] = {"output": array}
store_outputs_in_response = False
if "store_outputs_in_response" in self._params:
store_outputs_in_response = self._params["store_outputs_in_response"]
store_outputs_in_response_set = set()
if store_outputs_in_response:
store_outputs_in_response_set.add("output")
await request.response_sender().send(
outputs=outputs,
final=True,
store_outputs_in_response=store_outputs_in_response_set,
)
for output in outputs.values():
del output
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed.worker.operator import Operator
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.remote_request import RemoteInferenceRequest
from tritonserver import TritonError
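# Mock of a disaggregated LLM serving pipeline used in tests: each request
# flows through remote "preprocessing", "context" (prefill), "generation", and
# "postprocessing" operators, with decoded output streamed back to the client
# as it becomes available.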
class MockDisaggregatedServing(Operator):
def __init__(
self,
name,
version,
triton_core,
request_plane,
data_plane,
params,
repository,
logger,
):
self._triton_core = triton_core
self._request_plane = request_plane
self._data_plane = data_plane
self._params = params
self._preprocessing_model = RemoteOperator(
"preprocessing", 1, self._request_plane, self._data_plane
)
self._context_model = RemoteOperator(
"context", 1, self._request_plane, self._data_plane
)
self._generate_model = RemoteOperator(
"generation", 1, self._request_plane, self._data_plane
)
self._postprocessing_model = RemoteOperator(
"postprocessing", 1, self._request_plane, self._data_plane
)
self._logger = logger
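    # Feeds the KV cache and requested output length from the context (prefill)
    # response to the remote generation operator and post-processes each
    # streamed non-final generation response, releasing generation tensors when
    # done.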
async def _run_generate(self, context_response, response_sender):
error = None
generate_inputs = {}
if not error:
for output_name in ["KV_CACHE", "REQUEST_OUTPUT_LEN"]:
if output_name not in context_response.outputs.keys():
                    error_msg = f"Expected '{output_name}' as output in context model response, got outputs {context_response.outputs.keys()}"
self._logger.error(error_msg)
self._logger.debug(f"context_response: {context_response}")
error = TritonError(error_msg)
else:
generate_inputs[output_name] = context_response.outputs[output_name]
postproc_result = []
generate_responses = []
if not error:
try:
# TODO: Run post-processing in parallel with generate
async for response in await self._generate_model.async_infer(
inputs=generate_inputs
):
generate_responses.append(response)
self._logger.debug(f"Received response {response}")
if not generate_responses[-1].final:
postproc_result.append(
await self._run_postprocessing(
generate_responses[-1], response_sender, final=False
)
)
except Exception as e:
error = TritonError(repr(e))
self._logger.exception("Failed to run post-processing")
for generate_response in generate_responses:
for tensor in generate_response.outputs.values():
del tensor
return postproc_result
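    # Sends the output token ids and sequence length from an LLM response to
    # the remote post-processing operator and forwards each decoded "OUTPUT"
    # tensor to the client through the request's response sender.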
async def _run_postprocessing(self, llm_response, response_sender, final):
self._logger.debug(f"going to run_post_processing final={final}")
postproc_inputs = {}
for output_name in ["OUTPUT_IDS", "SEQUENCE_LENGTH"]:
if output_name not in llm_response.outputs.keys():
                error_msg = f"Expected '{output_name}' as output in llm model response, got outputs {llm_response.outputs.keys()}"
self._logger.error(error_msg)
self._logger.debug(f"llm_response: {llm_response}")
raise Exception(error_msg)
else:
postproc_inputs[output_name] = llm_response.outputs[output_name]
outputs = {}
postproc_responses = []
# TODO: Run post-processing in parallel with generate
self._logger.debug(f"Sending request to post-process {postproc_inputs}")
sending = []
async for response in await self._postprocessing_model.async_infer(
inputs=postproc_inputs
):
            self._logger.debug(f"Received response from post-process {response}")
postproc_responses.append(response)
outputs["output"] = postproc_responses[-1].outputs["OUTPUT"]
            sending.append(await response_sender().send(outputs=outputs, final=final))
return sending
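    # Per request: run pre-processing, feed its outputs to the context (prefill)
    # model, then launch post-processing of the context response and the
    # generation stage, and finish with a final (possibly error) response once
    # all pending results are awaited and intermediate tensors are released.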
async def execute(self, requests: list[RemoteInferenceRequest]):
print("in execute!", flush=True)
        for request in requests:
            error = None
"""
Pre-processing
"""
preproc_responses = []
async for response in await self._preprocessing_model.async_infer(
inference_request=request
):
preproc_responses.append(response)
if not error and len(preproc_responses) != 1:
error_msg = f"Expected exactly 1 response from preprocessing model, Got {len(preproc_responses)}"
self._logger.error(error_msg)
error = TritonError(error_msg)
context_inputs = {}
if not error:
for output_name in ["INPUT_IDS", "INPUT_LENGTH", "REQUEST_OUTPUT_LEN"]:
if output_name not in preproc_responses[0].outputs.keys():
error_msg = f"Expected '{output_name}' as output in preprocessing model response, Got outputs {response.outputs.keys()}"
self._logger.error(error_msg)
error = TritonError(error_msg)
else:
context_inputs[output_name] = preproc_responses[0].outputs[
output_name
]
"""
Prefill
"""
context_responses = []
postproc_result = []
if not error:
async for response in await self._context_model.async_infer(
inputs=context_inputs
):
context_responses.append(response)
if not error:
                if len(context_responses) != 1:
error_msg = f"Expected exactly 1 response from context model, Got {len(context_responses)}"
self._logger.error(error_msg)
error = TritonError(error_msg)
else:
postproc_result.append(
self._run_postprocessing(
context_responses[0], request.response_sender, final=False
)
)
"""
Generate
"""
if not error:
postproc_result.append(
self._run_generate(context_responses[0], request.response_sender)
)
for result in postproc_result:
try:
await result
except Exception as e:
self._logger.exception(
f"Failed getting response from post-processing {result}: {e}"
)
error = TritonError(repr(e))
            if preproc_responses:
                for tensor in preproc_responses[0].outputs.values():
                    del tensor
            if context_responses:
                for tensor in context_responses[0].outputs.values():
                    del tensor
if error:
await request.response_sender().send(error=error, final=True)
else:
await request.response_sender().send(final=True)