Commit ccfe101e authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

build: pip based installation of icp and runtime. Also make tritonserver optional


Signed-off-by: Neelay Shah <neelays@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 5701753a
...@@ -19,7 +19,13 @@ import abc ...@@ -19,7 +19,13 @@ import abc
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Optional, Type from typing import Any, Optional, Type
from tritonserver import Server try:
from tritonserver import Server as TritonCore
TRITON_CORE_AVAILABLE = True
except ImportError:
TRITON_CORE_AVAILABLE = False
TritonCore = type(None) # type: ignore[misc,assignment]
from triton_distributed.icp.data_plane import DataPlane from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane from triton_distributed.icp.request_plane import RequestPlane
...@@ -32,7 +38,6 @@ class Operator(abc.ABC): ...@@ -32,7 +38,6 @@ class Operator(abc.ABC):
self, self,
name: str, name: str,
version: int, version: int,
triton_core: Server,
request_plane: RequestPlane, request_plane: RequestPlane,
data_plane: DataPlane, data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field( parameters: Optional[dict[str, str | int | bool | bytes]] = field(
...@@ -40,6 +45,7 @@ class Operator(abc.ABC): ...@@ -40,6 +45,7 @@ class Operator(abc.ABC):
), ),
repository: Optional[str] = None, repository: Optional[str] = None,
logger: Optional[Any] = None, logger: Optional[Any] = None,
triton_core: Optional[TritonCore] = None,
): ):
pass pass
......
...@@ -19,8 +19,6 @@ import asyncio ...@@ -19,8 +19,6 @@ import asyncio
import uuid import uuid
from typing import Optional from typing import Optional
from tritonserver import InvalidArgumentError
from triton_distributed.icp.data_plane import DataPlane from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.runtime.remote_request import RemoteInferenceRequest from triton_distributed.runtime.remote_request import RemoteInferenceRequest
...@@ -89,16 +87,14 @@ class RemoteOperator: ...@@ -89,16 +87,14 @@ class RemoteOperator:
inference_request.model_name = self.name inference_request.model_name = self.name
inference_request.model_version = self.version inference_request.model_version = self.version
if inference_request.data_plane != self.data_plane: if inference_request.data_plane != self.data_plane:
raise InvalidArgumentError( raise ValueError(
"Data plane mismatch between remote request and remote operator: \n\n Operator: {self.data_plane} \n\n Request: {inference_request.data_plane}" "Data plane mismatch between remote request and remote operator: \n\n Operator: {self.data_plane} \n\n Request: {inference_request.data_plane}"
) )
if (inference_request.response_queue is not None) and ( if (inference_request.response_queue is not None) and (
not isinstance(inference_request.response_queue, asyncio.Queue) not isinstance(inference_request.response_queue, asyncio.Queue)
): ):
raise InvalidArgumentError( raise ValueError("asyncio.Queue must be used for async response iterator")
"asyncio.Queue must be used for async response iterator"
)
response_iterator = AsyncRemoteResponseIterator( response_iterator = AsyncRemoteResponseIterator(
self._data_plane, self._data_plane,
inference_request, inference_request,
......
...@@ -24,12 +24,10 @@ from collections import Counter ...@@ -24,12 +24,10 @@ from collections import Counter
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Optional from typing import Any, Optional
import tritonserver
from tritonserver import InferenceRequest, InvalidArgumentError, Tensor
from triton_distributed.icp.data_plane import DataPlane from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from triton_distributed.icp.request_plane import RequestPlane, get_icp_component_id from triton_distributed.icp.request_plane import RequestPlane, get_icp_component_id
from triton_distributed.icp.tensor import Tensor
from triton_distributed.runtime.remote_response import RemoteInferenceResponse from triton_distributed.runtime.remote_response import RemoteInferenceResponse
from triton_distributed.runtime.remote_tensor import RemoteTensor from triton_distributed.runtime.remote_tensor import RemoteTensor
...@@ -51,16 +49,6 @@ class RemoteInferenceRequest: ...@@ -51,16 +49,6 @@ class RemoteInferenceRequest:
parameters: dict[str, str | int | bool | float] = field(default_factory=dict) parameters: dict[str, str | int | bool | float] = field(default_factory=dict)
response_queue: Optional[queue.SimpleQueue | asyncio.Queue] = None response_queue: Optional[queue.SimpleQueue | asyncio.Queue] = None
def _set_local_request_inputs(self, local_request: tritonserver.InferenceRequest):
for input_name, remote_tensor in self.inputs.items():
local_request.inputs[input_name] = remote_tensor.local_tensor
def _set_local_request_parameters(
self, local_request: tritonserver.InferenceRequest
):
for parameter_name, parameter_value in self.parameters.items():
local_request.parameters[parameter_name] = parameter_value
def _set_model_infer_request_inputs( def _set_model_infer_request_inputs(
self, self,
remote_request: ModelInferRequest, remote_request: ModelInferRequest,
...@@ -127,7 +115,7 @@ class RemoteInferenceRequest: ...@@ -127,7 +115,7 @@ class RemoteInferenceRequest:
def response_sender(self): def response_sender(self):
if self._request_plane is None or self._model_infer_request is None: if self._request_plane is None or self._model_infer_request is None:
raise InvalidArgumentError( raise ValueError(
"Response only valid for requests received from request plane" "Response only valid for requests received from request plane"
) )
return RemoteResponseSender( return RemoteResponseSender(
...@@ -161,23 +149,6 @@ class RemoteInferenceRequest: ...@@ -161,23 +149,6 @@ class RemoteInferenceRequest:
RemoteInferenceRequest._set_parameters_from_model_infer_request(result, request) RemoteInferenceRequest._set_parameters_from_model_infer_request(result, request)
return result return result
def to_local_request(self, model: tritonserver.Model) -> InferenceRequest:
local_request = model.create_request()
if self.request_id is not None:
local_request.request_id = self.request_id
if self.priority is not None:
local_request.priority = self.priority
if self.timeout is not None:
local_request.timeout = self.timeout
if self.correlation_id is not None:
local_request.correlation_id = self.correlation_id
self._set_local_request_inputs(local_request)
self._set_local_request_parameters(local_request)
return local_request
def to_model_infer_request(self) -> ModelInferRequest: def to_model_infer_request(self) -> ModelInferRequest:
remote_request = ModelInferRequest() remote_request = ModelInferRequest()
remote_request.model_name = self.model_name remote_request.model_name = self.model_name
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Class for receiving inference responses to Triton Inference Server Models""" """Class for receiving inference responses to Triton Distributed Operators"""
from __future__ import annotations from __future__ import annotations
...@@ -27,18 +27,22 @@ from triton_distributed.icp.protos.icp_pb2 import ModelInferResponse ...@@ -27,18 +27,22 @@ from triton_distributed.icp.protos.icp_pb2 import ModelInferResponse
if TYPE_CHECKING: if TYPE_CHECKING:
from triton_distributed.runtime.remote_request import RemoteInferenceRequest from triton_distributed.runtime.remote_request import RemoteInferenceRequest
import uuid try:
from tritonserver import Tensor as TritonTensor
except ImportError:
TritonTensor = type(None) # type: ignore [misc, assignment]
from tritonserver import InternalError, Tensor, TritonError import uuid
from tritonserver._api._response import InferenceResponse
from triton_distributed.icp.request_plane import ( from triton_distributed.icp.request_plane import (
RequestPlaneError,
get_icp_component_id, get_icp_component_id,
get_icp_final_response, get_icp_final_response,
get_icp_response_error, get_icp_response_error,
set_icp_final_response, set_icp_final_response,
set_icp_response_error, set_icp_response_error,
) )
from triton_distributed.icp.tensor import Tensor
from triton_distributed.runtime.logger import get_logger from triton_distributed.runtime.logger import get_logger
from triton_distributed.runtime.remote_tensor import RemoteTensor from triton_distributed.runtime.remote_tensor import RemoteTensor
...@@ -107,7 +111,7 @@ class AsyncRemoteResponseIterator: ...@@ -107,7 +111,7 @@ class AsyncRemoteResponseIterator:
responses = server.model("test").async_infer(inputs={"fp16_input":numpy.array([[1]],dtype=numpy.float16)}) responses = server.model("test").async_infer(inputs={"fp16_input":numpy.array([[1]],dtype=numpy.float16)})
async for response in responses: async for response in responses:
print(nummpy.from_dlpack(response.outputs["fp16_output"])) print(numpy.from_dlpack(response.outputs["fp16_output"]))
""" """
...@@ -165,7 +169,7 @@ class AsyncRemoteResponseIterator: ...@@ -165,7 +169,7 @@ class AsyncRemoteResponseIterator:
def _response_handler(self, response: ModelInferResponse): def _response_handler(self, response: ModelInferResponse):
try: try:
if self._request is None: if self._request is None:
raise InternalError("Response received after final response flag") raise ValueError("Response received after final response flag")
final = False final = False
...@@ -203,8 +207,6 @@ class RemoteInferenceResponse: ...@@ -203,8 +207,6 @@ class RemoteInferenceResponse:
reported and a flag to indicate if the response is the final one reported and a flag to indicate if the response is the final one
for a request. for a request.
See c:func:`TRITONSERVER_InferenceResponse` for more details
Parameters Parameters
---------- ----------
model : Model model : Model
...@@ -215,7 +217,7 @@ class RemoteInferenceResponse: ...@@ -215,7 +217,7 @@ class RemoteInferenceResponse:
Additional parameters associated with the response. Additional parameters associated with the response.
outputs : dict [str, Tensor], default {} outputs : dict [str, Tensor], default {}
Output tensors for the inference. Output tensors for the inference.
error : Optional[TritonError], default None error : Optional[RequestPlaneError], default None
Error (if any) that occurred in the processing of the request. Error (if any) that occurred in the processing of the request.
classification_label : Optional[str], default None classification_label : Optional[str], default None
Classification label associated with the inference. Not currently supported. Classification label associated with the inference. Not currently supported.
...@@ -231,7 +233,7 @@ class RemoteInferenceResponse: ...@@ -231,7 +233,7 @@ class RemoteInferenceResponse:
parameters: dict[str, str | int | bool] = field(default_factory=dict) parameters: dict[str, str | int | bool] = field(default_factory=dict)
outputs: dict[str, RemoteTensor | Tensor] = field(default_factory=dict) outputs: dict[str, RemoteTensor | Tensor] = field(default_factory=dict)
store_outputs_in_response: set[str] = field(default_factory=set) store_outputs_in_response: set[str] = field(default_factory=set)
error: Optional[TritonError] = None error: Optional[RequestPlaneError] = None
classification_label: Optional[str] = None classification_label: Optional[str] = None
final: bool = False final: bool = False
...@@ -253,7 +255,9 @@ class RemoteInferenceResponse: ...@@ -253,7 +255,9 @@ class RemoteInferenceResponse:
): ):
for name, value in self.outputs.items(): for name, value in self.outputs.items():
if not isinstance(value, RemoteTensor): if not isinstance(value, RemoteTensor):
if not isinstance(value, Tensor): if not isinstance(value, Tensor) and not isinstance(
value, TritonTensor
):
tensor = Tensor._from_object(value) tensor = Tensor._from_object(value)
else: else:
tensor = value tensor = value
...@@ -294,29 +298,6 @@ class RemoteInferenceResponse: ...@@ -294,29 +298,6 @@ class RemoteInferenceResponse:
self._set_model_infer_response_outputs(remote_response, data_plane) self._set_model_infer_response_outputs(remote_response, data_plane)
return remote_response return remote_response
@staticmethod
def from_local_response(
local_response: InferenceResponse, store_outputs_in_response: bool = False
):
result = RemoteInferenceResponse(
local_response.model.name,
local_response.model.version,
None,
local_response.request_id,
final=local_response.final,
)
for tensor_name, tensor_value in local_response.outputs.items():
result.outputs[tensor_name] = tensor_value
if store_outputs_in_response:
result.store_outputs_in_response.add(tensor_name)
for parameter_name, parameter_value in local_response.parameters.items():
result.parameters[parameter_name] = parameter_value
result.error = local_response.error
return result
@staticmethod @staticmethod
def from_model_infer_response( def from_model_infer_response(
request: RemoteInferenceRequest, request: RemoteInferenceRequest,
......
...@@ -20,13 +20,8 @@ from typing import Optional, Sequence ...@@ -20,13 +20,8 @@ from typing import Optional, Sequence
import cupy import cupy
from cupy_backends.cuda.api.runtime import CUDARuntimeError from cupy_backends.cuda.api.runtime import CUDARuntimeError
from tritonserver import DataType, InvalidArgumentError, MemoryType, Tensor
# TODO
# Export from tritonserver
from tritonserver._api._dlpack import DLDeviceType
from tritonserver._api._tensor import DeviceOrMemoryType
from triton_distributed.icp._dlpack import DeviceOrMemoryType, DLDeviceType
from triton_distributed.icp.data_plane import ( from triton_distributed.icp.data_plane import (
DataPlane, DataPlane,
get_icp_data_type, get_icp_data_type,
...@@ -34,7 +29,10 @@ from triton_distributed.icp.data_plane import ( ...@@ -34,7 +29,10 @@ from triton_distributed.icp.data_plane import (
get_icp_shape, get_icp_shape,
get_icp_tensor_size, get_icp_tensor_size,
) )
from triton_distributed.icp.data_type import DataType
from triton_distributed.icp.memory_type import MemoryType
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.tensor import Tensor
# Run cupy's cuda.is_available once to # Run cupy's cuda.is_available once to
# avoid the exception hitting runtime code. # avoid the exception hitting runtime code.
...@@ -74,7 +72,7 @@ class RemoteTensor: ...@@ -74,7 +72,7 @@ class RemoteTensor:
if not self._local_tensor: if not self._local_tensor:
self._local_tensor = self.data_plane.get_tensor(self.remote_tensor) self._local_tensor = self.data_plane.get_tensor(self.remote_tensor)
if self._local_tensor is None: if self._local_tensor is None:
raise InvalidArgumentError("Not able to resolve Tensor locally") raise ValueError("Not able to resolve Tensor locally")
return self._local_tensor return self._local_tensor
@property @property
......
...@@ -20,12 +20,24 @@ import os ...@@ -20,12 +20,24 @@ import os
import uuid import uuid
from typing import Optional from typing import Optional
# Triton Core (the in-process `tritonserver` package) is an optional
# dependency for the rest of the runtime, but this module cannot work
# without it — fail fast at import time with a clear message instead of
# raising a confusing NameError later.
try:
    import tritonserver
    from tritonserver import DataType as TritonDataType
    from tritonserver import InvalidArgumentError
    from tritonserver import MemoryBuffer as TritonMemoryBuffer
    from tritonserver import MemoryType as TritonMemoryType
    from tritonserver import Server as TritonCore
    from tritonserver import Tensor as TritonTensor
    from tritonserver._api._response import InferenceResponse
except ImportError as e:
    # Chain the original ImportError so the underlying cause stays visible.
    raise ImportError("Triton Core is not installed") from e
from google.protobuf import json_format, text_format from google.protobuf import json_format, text_format
from tritonclient.grpc import model_config_pb2 from tritonclient.grpc import model_config_pb2
from tritonserver import InvalidArgumentError, Server
from triton_distributed.icp.data_plane import DataPlane from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.icp.tensor import Tensor
from triton_distributed.runtime.logger import get_logger from triton_distributed.runtime.logger import get_logger
from triton_distributed.runtime.operator import Operator from triton_distributed.runtime.operator import Operator
from triton_distributed.runtime.remote_request import RemoteInferenceRequest from triton_distributed.runtime.remote_request import RemoteInferenceRequest
...@@ -37,12 +49,12 @@ class TritonCoreOperator(Operator): ...@@ -37,12 +49,12 @@ class TritonCoreOperator(Operator):
self, self,
name: str, name: str,
version: int, version: int,
triton_core: Server,
request_plane: RequestPlane, request_plane: RequestPlane,
data_plane: DataPlane, data_plane: DataPlane,
parameters: dict, parameters: dict,
repository: Optional[str] = None, repository: Optional[str] = None,
logger: logging.Logger = get_logger(__name__), logger: logging.Logger = get_logger(__name__),
triton_core: Optional[TritonCore] = None,
): ):
self._repository = repository self._repository = repository
self._name = name self._name = name
...@@ -56,11 +68,14 @@ class TritonCoreOperator(Operator): ...@@ -56,11 +68,14 @@ class TritonCoreOperator(Operator):
"store_outputs_in_response", False "store_outputs_in_response", False
) )
if self._triton_core is None:
raise ValueError("Triton Core required for TritonCoreOperator")
if not self._repository: if not self._repository:
self._repository = "." self._repository = "."
if repository: if repository:
triton_core.register_model_repository(repository) self._triton_core.register_model_repository(repository)
parameter_config = self._parameters.get("config", None) parameter_config = self._parameters.get("config", None)
...@@ -88,7 +103,79 @@ class TritonCoreOperator(Operator): ...@@ -88,7 +103,79 @@ class TritonCoreOperator(Operator):
model_config = {"config": parameter_config} model_config = {"config": parameter_config}
else: else:
model_config = None model_config = None
self._local_model = self._triton_core.load(self._name, model_config) self._triton_core_model = self._triton_core.load(self._name, model_config)
@staticmethod
def _triton_tensor(tensor: Tensor) -> TritonTensor:
    """Wrap a triton_distributed Tensor in an equivalent Triton Core tensor.

    The returned TritonTensor aliases the same memory buffer (same data
    pointer, memory type/id, size and owner) as *tensor* — no data is
    copied, only the wrapper types change.
    """
    source_buffer = tensor.memory_buffer
    core_buffer = TritonMemoryBuffer(
        source_buffer.data_ptr,
        TritonMemoryType(source_buffer.memory_type),
        source_buffer.memory_type_id,
        source_buffer.size,
        source_buffer.owner,
    )
    return TritonTensor(TritonDataType(tensor.data_type), tensor.shape, core_buffer)
@staticmethod
def _triton_core_request(
    request: RemoteInferenceRequest, model: tritonserver.Model
) -> tritonserver.InferenceRequest:
    """Build a local Triton Core request mirroring a remote request.

    Optional scalar fields (request id, priority, timeout, correlation id)
    are copied only when they were actually set on *request*; inputs and
    parameters are copied via the companion helpers.
    """
    core_request = model.create_request()
    # Only carry over the optional scalars the caller provided.
    for field_name in ("request_id", "priority", "timeout", "correlation_id"):
        value = getattr(request, field_name)
        if value is not None:
            setattr(core_request, field_name, value)
    TritonCoreOperator._set_inputs(request, core_request)
    TritonCoreOperator._set_parameters(request, core_request)
    return core_request
@staticmethod
def _set_inputs(
    request: RemoteInferenceRequest, local_request: tritonserver.InferenceRequest
):
    """Copy every input of *request* into *local_request* as a Triton Core tensor.

    Each remote tensor is first resolved to a local tensor, then converted
    with ``_triton_tensor`` before being assigned under the same input name.
    """
    convert = TritonCoreOperator._triton_tensor
    local_inputs = local_request.inputs
    for name, remote_tensor in request.inputs.items():
        local_inputs[name] = convert(remote_tensor.local_tensor)
@staticmethod
def _set_parameters(
    request: RemoteInferenceRequest, local_request: tritonserver.InferenceRequest
):
    """Copy all parameters from *request* onto *local_request*.

    Existing parameters of the same name on the local request are
    overwritten.
    """
    local_parameters = local_request.parameters
    for name in request.parameters:
        local_parameters[name] = request.parameters[name]
@staticmethod
def _remote_response(
    triton_core_response: InferenceResponse, store_outputs_in_response: bool = False
) -> RemoteInferenceResponse:
    """Translate a local Triton Core response into a RemoteInferenceResponse.

    Output tensors, parameters, the error (if any), and the final flag are
    carried over unchanged. When *store_outputs_in_response* is set, every
    output name is also marked to be embedded directly in the response
    message rather than sent via the data plane.
    """
    model = triton_core_response.model
    remote = RemoteInferenceResponse(
        model.name,
        model.version,
        None,
        triton_core_response.request_id,
        final=triton_core_response.final,
    )
    remote.outputs.update(triton_core_response.outputs)
    if store_outputs_in_response:
        # Iterating the outputs mapping yields the tensor names.
        remote.store_outputs_in_response.update(triton_core_response.outputs)
    remote.parameters.update(triton_core_response.parameters)
    remote.error = triton_core_response.error
    return remote
async def execute(self, requests: list[RemoteInferenceRequest]) -> None: async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
request_id_map = {} request_id_map = {}
...@@ -96,7 +183,9 @@ class TritonCoreOperator(Operator): ...@@ -96,7 +183,9 @@ class TritonCoreOperator(Operator):
for request in requests: for request in requests:
self._logger.debug("\n\nReceived request: \n\n%s\n\n", request) self._logger.debug("\n\nReceived request: \n\n%s\n\n", request)
try: try:
local_request = request.to_local_request(self._local_model) triton_core_request = TritonCoreOperator._triton_core_request(
request, self._triton_core_model
)
except Exception as e: except Exception as e:
message = f"Can't resolve tensors for request, ignoring request,{e}" message = f"Can't resolve tensors for request, ignoring request,{e}"
self._logger.error(message) self._logger.error(message)
...@@ -107,25 +196,27 @@ class TritonCoreOperator(Operator): ...@@ -107,25 +196,27 @@ class TritonCoreOperator(Operator):
request_id = str(uuid.uuid1()) request_id = str(uuid.uuid1())
original_id = None original_id = None
if local_request.request_id is not None: if triton_core_request.request_id is not None:
original_id = local_request.request_id original_id = triton_core_request.request_id
local_request.request_id = request_id triton_core_request.request_id = request_id
request_id_map[request_id] = (request.response_sender(), original_id) request_id_map[request_id] = (request.response_sender(), original_id)
local_request.response_queue = response_queue triton_core_request.response_queue = response_queue
self._local_model.async_infer(local_request) self._triton_core_model.async_infer(triton_core_request)
while request_id_map: while request_id_map:
local_response = await response_queue.get() triton_core_response = await response_queue.get()
remote_response = RemoteInferenceResponse.from_local_response( remote_response = TritonCoreOperator._remote_response(
local_response, self._store_outputs_in_response triton_core_response, self._store_outputs_in_response
) )
response_sender, original_id = request_id_map[local_response.request_id] response_sender, original_id = request_id_map[
triton_core_response.request_id
]
remote_response.request_id = original_id remote_response.request_id = original_id
if local_response.final: if triton_core_response.final:
del request_id_map[local_response.request_id] del request_id_map[triton_core_response.request_id]
self._logger.debug("\n\nSending response\n\n%s\n\n", remote_response) self._logger.debug("\n\nSending response\n\n%s\n\n", remote_response)
await response_sender.send(remote_response) await response_sender.send(remote_response)
...@@ -24,7 +24,17 @@ from collections import Counter ...@@ -24,7 +24,17 @@ from collections import Counter
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type
import tritonserver try:
from tritonserver import ModelControlMode as ModelControlMode
from tritonserver import Server as TritonCore
from triton_distributed.runtime.triton_core_operator import TritonCoreOperator
TRITON_CORE_AVAILABLE = True
except ImportError:
TRITON_CORE_AVAILABLE = False
TritonCoreOperator = type(None)
TritonCore = type(None) # type: ignore[misc,assignment]
from triton_distributed.icp.data_plane import DataPlane from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.nats_request_plane import NatsRequestPlane from triton_distributed.icp.nats_request_plane import NatsRequestPlane
...@@ -36,7 +46,6 @@ from triton_distributed.runtime.remote_request import ( ...@@ -36,7 +46,6 @@ from triton_distributed.runtime.remote_request import (
RemoteInferenceRequest, RemoteInferenceRequest,
RemoteResponseSender, RemoteResponseSender,
) )
from triton_distributed.runtime.triton_core_operator import TritonCoreOperator
if TYPE_CHECKING: if TYPE_CHECKING:
import uvicorn import uvicorn
...@@ -89,7 +98,7 @@ class Worker: ...@@ -89,7 +98,7 @@ class Worker:
self._metrics_port = config.metrics_port self._metrics_port = config.metrics_port
self._metrics_server: Optional[uvicorn.Server] = None self._metrics_server: Optional[uvicorn.Server] = None
self._component_id = self._request_plane.component_id self._component_id = self._request_plane.component_id
self._triton_core: Optional[tritonserver.Server] = None self._triton_core: Optional[TritonCore] = None
self._log_file: Optional[pathlib.Path] = None self._log_file: Optional[pathlib.Path] = None
if self._log_dir: if self._log_dir:
path = pathlib.Path(self._log_dir) path = pathlib.Path(self._log_dir)
...@@ -148,6 +157,10 @@ class Worker: ...@@ -148,6 +157,10 @@ class Worker:
class_ == TritonCoreOperator class_ == TritonCoreOperator
or issubclass(class_, TritonCoreOperator) or issubclass(class_, TritonCoreOperator)
) and not self._triton_core: ) and not self._triton_core:
if not TRITON_CORE_AVAILABLE:
raise ValueError(
"Please install Triton Core to use a Triton Core Operator"
)
if not self._consolidate_logs and self._log_file: if not self._consolidate_logs and self._log_file:
log_file = pathlib.Path(self._log_file) log_file = pathlib.Path(self._log_file)
stem = log_file.stem stem = log_file.stem
...@@ -157,24 +170,24 @@ class Worker: ...@@ -157,24 +170,24 @@ class Worker:
) )
else: else:
triton_log_path = str(self._log_file) triton_log_path = str(self._log_file)
self._triton_core = tritonserver.Server( self._triton_core = TritonCore(
model_repository=".", model_repository=".",
log_error=True, log_error=True,
log_verbose=self._log_level, log_verbose=self._log_level,
strict_model_config=False, strict_model_config=False,
model_control_mode=tritonserver.ModelControlMode.EXPLICIT, model_control_mode=ModelControlMode.EXPLICIT,
log_file=triton_log_path, log_file=triton_log_path,
).start(wait_until_ready=True) ).start(wait_until_ready=True)
operator = class_( operator = class_(
operator_config.name, operator_config.name,
operator_config.version, operator_config.version,
self._triton_core,
self._request_plane, self._request_plane,
self._data_plane, self._data_plane,
operator_config.parameters, operator_config.parameters,
operator_config.repository, operator_config.repository,
operator_logger, operator_logger,
self._triton_core,
) )
except Exception as e: except Exception as e:
logger.exception( logger.exception(
......
...@@ -25,12 +25,12 @@ class AddMultiplyDivide(Operator): ...@@ -25,12 +25,12 @@ class AddMultiplyDivide(Operator):
self, self,
name, name,
version, version,
triton_core,
request_plane, request_plane,
data_plane, data_plane,
parameters, parameters,
repository, repository,
logger, logger,
triton_core,
): ):
self._triton_core = triton_core self._triton_core = triton_core
self._request_plane = request_plane self._request_plane = request_plane
......
...@@ -28,12 +28,12 @@ class Identity(Operator): ...@@ -28,12 +28,12 @@ class Identity(Operator):
self, self,
name, name,
version, version,
triton_core,
request_plane, request_plane,
data_plane, data_plane,
params, params,
repository, repository,
logger, logger,
triton_core,
): ):
self._triton_core = triton_core self._triton_core = triton_core
self._request_plane = request_plane self._request_plane = request_plane
......
...@@ -26,12 +26,12 @@ class MockDisaggregatedServing(Operator): ...@@ -26,12 +26,12 @@ class MockDisaggregatedServing(Operator):
self, self,
name, name,
version, version,
triton_core,
request_plane, request_plane,
data_plane, data_plane,
params, params,
repository, repository,
logger, logger,
triton_core,
): ):
self._triton_core = triton_core self._triton_core = triton_core
self._request_plane = request_plane self._request_plane = request_plane
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment