Commit 0bfd9a76 authored by Neelay Shah, committed by GitHub

refactor: remove python native runtime

parent 8f741f14
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import dataclasses
import uuid
from datetime import datetime
from typing import Any, Optional, Type
import msgspec
from triton_distributed.icp.event_plane import Event, EventTopic
@dataclasses.dataclass
class EventMetadata:
"""
Holds the metadata of an event.
"""
event_id: uuid.UUID
event_type: str
timestamp: datetime
component_id: uuid.UUID
event_topic: Optional[EventTopic] = None
def _deserialize_metadata(event_metadata_serialized: bytes):
event_metadata_dict = msgspec.json.decode(event_metadata_serialized)
topic_meta = event_metadata_dict["event_topic"]
topic_list = topic_meta["event_topic"].split(".") if topic_meta else []
topic_obj = EventTopic(topic_list)
metadata = EventMetadata(
**{
**event_metadata_dict,
"event_topic": topic_obj,
"event_id": uuid.UUID(event_metadata_dict["event_id"]),
"component_id": uuid.UUID(event_metadata_dict["component_id"]),
"timestamp": datetime.fromisoformat(event_metadata_dict["timestamp"]),
}
)
return metadata
def _serialize_metadata(event_metadata: EventMetadata) -> bytes:
def hook(obj):
if isinstance(obj, uuid.UUID):
return str(obj)
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, EventTopic):
return list(obj.event_topic.split("."))
else:
raise NotImplementedError(f"Type {type(obj)} is not serializable.")
json_string = msgspec.json.encode(event_metadata, enc_hook=hook)
return json_string
def _get_type(type_name: str):
# Check in builtins for the type
builtin_type = getattr(builtins, type_name, None)
if builtin_type and isinstance(builtin_type, type):
return builtin_type
# Check in globals for the type
global_type = globals().get(type_name)
if global_type and isinstance(global_type, type):
return global_type
return None
class OnDemandEvent(Event):
"""LazyEvent class for representing events."""
def __init__(
self,
payload: bytes,
event_metadata_serialized: bytes,
event_metadata: Optional[EventMetadata] = None,
):
"""Initialize the event.
Args:
payload (bytes): Serialized event payload.
event_metadata_serialized (bytes): Serialized event metadata.
event_metadata (Optional[EventMetadata]): Already-deserialized event metadata;
when omitted, it is parsed lazily from event_metadata_serialized on first access.
"""
self._payload = payload
self._event_metadata_serialized = event_metadata_serialized
self._event_metadata = event_metadata
@property
def _metadata(self):
if not self._event_metadata:
self._event_metadata = _deserialize_metadata(
self._event_metadata_serialized
)
return self._event_metadata
@property
def event_id(self) -> uuid.UUID:
return self._metadata.event_id
@property
def event_type(self) -> str:
return self._metadata.event_type
@property
def timestamp(self) -> datetime:
return self._metadata.timestamp
@property
def component_id(self) -> uuid.UUID:
return self._metadata.component_id
@property
def event_topic(self) -> Optional[EventTopic]:
return self._metadata.event_topic
@property
def payload(self) -> bytes:
return self._payload
def typed_payload(self, payload_type: Optional[Type | str] = None) -> Any:
if payload_type is None:
payload_type = self.event_type
if isinstance(payload_type, str):
payload_type = _get_type(payload_type)
if payload_type is not None and payload_type is not bytes:
try:
return msgspec.json.decode(self._payload, type=payload_type)
except Exception as e:
raise ValueError(
f"Unable to convert payload {self._payload!r} to type {payload_type} from event type {self.event_type}"
) from e
elif payload_type is bytes:
return bytes(self._payload)
else:
raise ValueError(
f"Unable to convert payload {self._payload!r} to type {payload_type} from event type {self.event_type}"
)
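

def _example_on_demand_event() -> None:
    # Illustrative sketch only; this helper name is not part of the module.
    # OnDemandEvent wraps a payload plus serialized metadata and defers metadata
    # parsing until a metadata property is first accessed. Here the parsed
    # EventMetadata is supplied directly, so no deserialization runs.
    metadata = EventMetadata(
        event_id=uuid.uuid4(),
        event_type="str",
        timestamp=datetime.now(),
        component_id=uuid.uuid4(),
    )
    event = OnDemandEvent(
        payload=b'"hello"',
        event_metadata_serialized=b"",  # only parsed when event_metadata is absent
        event_metadata=metadata,
    )
    # event_type "str" resolves to the builtin str via _get_type, so the JSON
    # payload decodes to a Python string.
    assert event.typed_payload() == "hello"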
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract Class for interacting with the Triton Distributed Inter-Component Protocol Control Plane"""
import abc
import uuid
from typing import AsyncIterator, Awaitable, Callable, Optional
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
ICP_REQUEST_ID = "icp_request_id"
ICP_FINAL_RESPONSE = "icp_final_response"
ICP_RESPONSE_FROM_URI = "icp_response_from_uri"
ICP_COMPONENT_ID = "icp_component_id"
ICP_RESPONSE_TO_URI = "icp_response_to_uri"
ICP_REQUEST_TO_URI = "icp_request_to_uri"
ICP_REQUEST_CANCELLED = "icp_request_cancelled"
ICP_ERROR = "icp_response_error"
class RequestPlaneError(Exception):
pass
def get_icp_request_id(
message: ModelInferRequest | ModelInferResponse,
) -> uuid.UUID | None:
if ICP_REQUEST_ID not in message.parameters:
return None
return uuid.UUID(message.parameters[ICP_REQUEST_ID].string_param)
def set_icp_request_id(
message: ModelInferRequest | ModelInferResponse, value: uuid.UUID
) -> None:
message.parameters[ICP_REQUEST_ID].string_param = str(value)
def get_icp_response_error(message: ModelInferResponse) -> RequestPlaneError | None:
if ICP_ERROR not in message.parameters:
return None
return RequestPlaneError(message.parameters[ICP_ERROR].string_param)
def set_icp_response_error(
message: ModelInferResponse, value: RequestPlaneError
) -> None:
message.parameters[ICP_ERROR].string_param = str(value)
def get_icp_final_response(
message: ModelInferResponse,
) -> bool:
if ICP_FINAL_RESPONSE not in message.parameters:
return False
return message.parameters[ICP_FINAL_RESPONSE].bool_param
def set_icp_final_response(message: ModelInferResponse, value: bool) -> None:
message.parameters[ICP_FINAL_RESPONSE].bool_param = value
def get_icp_response_to_uri(message: ModelInferRequest) -> str | None:
if ICP_RESPONSE_TO_URI not in message.parameters:
return None
return message.parameters[ICP_RESPONSE_TO_URI].string_param
def get_icp_component_id(
message: ModelInferRequest | ModelInferResponse,
) -> uuid.UUID | None:
if ICP_COMPONENT_ID not in message.parameters:
return None
return uuid.UUID(message.parameters[ICP_COMPONENT_ID].string_param)
def set_icp_component_id(
message: ModelInferRequest | ModelInferResponse, value: uuid.UUID
) -> None:
message.parameters[ICP_COMPONENT_ID].string_param = str(value)
def set_icp_response_to_uri(message: ModelInferRequest, value: str) -> None:
message.parameters[ICP_RESPONSE_TO_URI].string_param = value
def get_icp_request_to_uri(message: ModelInferRequest) -> str | None:
if ICP_REQUEST_TO_URI not in message.parameters:
return None
return message.parameters[ICP_REQUEST_TO_URI].string_param
def set_icp_request_to_uri(message: ModelInferRequest, value: str) -> None:
message.parameters[ICP_REQUEST_TO_URI].string_param = value
def get_icp_response_from_uri(message: ModelInferResponse) -> str | None:
if ICP_RESPONSE_FROM_URI not in message.parameters:
return None
return message.parameters[ICP_RESPONSE_FROM_URI].string_param
def set_icp_response_from_uri(message: ModelInferResponse, value: str) -> None:
message.parameters[ICP_RESPONSE_FROM_URI].string_param = value
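

def _example_icp_parameters() -> None:
    # Illustrative sketch only; this helper name and the URI below are made up.
    # The get_/set_ helpers above store ICP routing metadata in the protobuf
    # "parameters" map of a request or response message.
    request = ModelInferRequest()
    set_icp_request_id(request, uuid.uuid4())
    set_icp_response_to_uri(request, "nats://localhost:4222/responses")
    assert get_icp_request_id(request) is not None
    assert get_icp_response_to_uri(request) == "nats://localhost:4222/responses"
    # Parameters that were never set read back as None (or False for flags).
    assert get_icp_request_to_uri(request) is None
    assert get_icp_final_response(ModelInferResponse()) is False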
class RequestPlane(abc.ABC):
@property
@abc.abstractmethod
def component_id(self) -> uuid.UUID:
pass
@abc.abstractmethod
async def connect(self) -> None:
pass
@abc.abstractmethod
async def pull_requests(
self,
model_name: str,
model_version: str,
number_requests: int = 1,
timeout: Optional[float] = None,
) -> AsyncIterator[ModelInferRequest]:
pass
@abc.abstractmethod
async def post_response(
self,
request: ModelInferRequest,
responses: AsyncIterator[ModelInferResponse] | ModelInferResponse,
) -> None:
pass
@abc.abstractmethod
async def post_request(
self,
request: ModelInferRequest,
*,
component_id: Optional[uuid.UUID] = None,
response_iterator: bool = True,
response_handler: Optional[
Callable[[ModelInferResponse], None | Awaitable[None]]
] = None,
) -> AsyncIterator[ModelInferResponse]:
pass
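

async def _example_worker_loop(request_plane: RequestPlane) -> None:
    # Illustrative sketch only; assumes a concrete RequestPlane implementation
    # whose pull_requests is an async generator. The model name and version
    # below are placeholders.
    await request_plane.connect()
    async for request in request_plane.pull_requests("my_model", "1"):
        # Produce a single, final response for each pulled request.
        response = ModelInferResponse()
        set_icp_final_response(response, True)
        await request_plane.post_response(request, response)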
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import ctypes
import struct
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Sequence
import numpy
from triton_distributed.icp._dlpack import (
DATA_TYPE_TO_DLPACK_DTYPE,
MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE,
DeviceOrMemoryType,
DLDevice,
DLDeviceType,
DLManagedTensor,
DLPackObject,
c_str_dltensor,
parse_device_or_memory_type,
)
from triton_distributed.icp.data_type import NUMPY_TO_DATA_TYPE, DataType
from triton_distributed.icp.memory_buffer import MemoryBuffer
from triton_distributed.icp.memory_type import MemoryType
try:
import cupy
except ImportError:
cupy = None
@dataclass
class Tensor:
"""Class representing a Tensor.
Parameters
----------
data_type : DataType
Data type of the tensor.
shape : Sequence[int]
Shape of the tensor.
memory_buffer : MemoryBuffer
Memory buffer containing the tensor data.
"""
data_type: DataType
shape: Sequence[int]
memory_buffer: MemoryBuffer
@property
def data_ptr(self) -> int:
"""Get the pointer to the tensor's data.
Returns
-------
int
The pointer to the tensor's data.
"""
return self.memory_buffer.data_ptr
@property
def memory_type(self) -> MemoryType:
"""Get the memory type of the tensor.
Returns
-------
MemoryType
The memory type of the tensor.
"""
return self.memory_buffer.memory_type
@property
def memory_type_id(self) -> int:
"""Get the ID representing the memory type of the tensor.
Returns
-------
int
The ID representing the memory type of the tensor.
"""
return self.memory_buffer.memory_type_id
@property
def size(self) -> int:
"""Get the size of the tensor's data in bytes.
Returns
-------
int
The size of the tensor's data in bytes.
"""
return self.memory_buffer.size
def _sync_on_requested_stream(self, requested_stream_ptr):
"""Stream synchronization based on cupy implementation.
Record event on requested stream if different from current
stream.
See
`https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html`
and
`https://github.com/cupy/cupy/blob/f9563bcd5c674623f80ce975b590f9c860b44ed6/cupy/_core/core.pyx#L281`
for more details.
Parameters
----------
requested_stream_ptr :
requested stream as defined by DLPack protocol
Raises
------
ValueError
If synchronization cannot be done.
"""
current_stream = None
unsupported = ValueError(
f"DLPack stream synchronization on memory type {self.memory_type} and stream {requested_stream_ptr} not supported"
)
if requested_stream_ptr is not None and not isinstance(
requested_stream_ptr, int
):
raise unsupported
if self.memory_type != MemoryType.GPU:
if requested_stream_ptr not in (None, 0):
raise unsupported
return
if cupy is None:
raise unsupported
# NOTE: Technically this is not required by the protocol. It is the
# responsibility of the caller(consumer) to ensure that
# we are on the correct device. Added to ensure
# the semantics are correct - but should be a no-op.
# May be removed in the future.
with cupy.cuda.Device(self.memory_type_id):
current_stream = cupy.cuda.get_current_stream()
curr_stream_ptr = current_stream.ptr
# Based on cupy documentation
# cupy.cuda.Stream.null.ptr is legacy default stream
# cupy.cuda.Stream.ptds.ptr is per thread default stream
if curr_stream_ptr == 0:
curr_stream_ptr = cupy.cuda.Stream.null.ptr
if requested_stream_ptr in (None, 0, 1):
requested_stream_ptr = cupy.cuda.Stream.null.ptr
if requested_stream_ptr in (2,):
requested_stream_ptr = cupy.cuda.Stream.ptds.ptr
if requested_stream_ptr >= 0 and curr_stream_ptr != requested_stream_ptr:
next_stream = cupy.cuda.ExternalStream(requested_stream_ptr)
event = current_stream.record()
next_stream.wait_event(event)
def __dlpack__(self, *, stream=None):
"""Convert the tensor to a DLPack-compatible object.
Parameters
----------
stream : Any, optional
The consumer's stream as defined by the DLPack protocol; used to
synchronize with the current stream before export, by default None.
Returns
-------
Any
A DLPack-compatible object representing the tensor.
"""
self._sync_on_requested_stream(stream)
dl_managed_tensor = self._create_managed_tensor()
pycapsule = ctypes.pythonapi.PyCapsule_New(
ctypes.byref(dl_managed_tensor),
c_str_dltensor,
Tensor._pycapsule_deleter,
)
return pycapsule
def __dlpack_device__(self) -> tuple[DLDeviceType, int]:
"""Get the DLPack device information for the tensor.
Returns
-------
tuple[DLDeviceType, int]
A tuple representing the DLPack device information (device type, device ID).
"""
return (
MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE[self.memory_type],
self.memory_type_id,
)
def to_string_array(self) -> numpy.ndarray:
"""Deserialize BYTES Tensor into numpy array of strings.
If memory is not on the host the tensor data will be copied to
the host before deserialization.
Returns
-------
numpy.ndarray
A numpy array of strings representing the BYTES tensor.
Examples
--------
numpy_ndarray = response.outputs["text_output"].to_string_array()
"""
return self.to_bytes_array().astype(str)
def to_bytes_array(self) -> numpy.ndarray:
"""Deserialize BYTES Tensor into numpy array.
If memory is not on the host the tensor data will be copied to
the host before deserialization.
Returns
-------
numpy.ndarray
A numpy array of objects representing the BYTES tensor.
Examples
--------
numpy_ndarray = response.outputs["text_output"].to_bytes_array()
"""
if self.data_type != DataType.BYTES:
raise ValueError(
f"Tensor has data type {self.data_type} not {DataType.BYTES}"
)
# Reshape into 1d array of bytes on host
original_data_type = self.data_type
original_shape = self.shape
self.data_type = DataType.UINT8
self.shape = [self.size]
numpy_ndarray = self._to_numpy_on_host()
# Deserialize bytes array and reshape
self.shape = original_shape
self.data_type = original_data_type
return Tensor._deserialize_bytes_array(numpy_ndarray).reshape(self.shape)
@staticmethod
def from_string_array(string_array: list[str] | numpy.ndarray) -> Tensor:
"""Create BYTES Tensor from numpy array of strings or list of strings.
Creates a tensor of type BYTES from a list of strings,
or numpy array of type str_. The
method allocates new host memory to store the serialized
tensor.
Parameters
----------
string_array : list[str] | numpy.ndarray
an array-like object to convert
Returns
-------
Tensor
Raises
------
ValueError
If the given object can not be converted.
Examples
--------
tensor = Tensor.from_string_array(numpy.array(["hello"]))
tensor = Tensor.from_string_array(["hello"])
"""
return Tensor.from_bytes_array(string_array)
@staticmethod
def from_bytes_array(
bytes_array: list[str] | list[bytes] | numpy.ndarray,
) -> Tensor:
"""Create BYTES Tensor from numpy array or list
Creates a tensor of type BYTES from a list of strings,
bytes or a numpy array of type object_, bytes_, or str_. The
method allocates new host memory to store the serialized
tensor.
Parameters
----------
bytes_array : list[str | bytes] | numpy.ndarray
an array-like object to convert
Returns
-------
Tensor
Raises
------
ValueError
If the given object can not be converted.
Examples
--------
tensor = Tensor.from_bytes_array(numpy.array(["hello"]))
tensor = Tensor.from_bytes_array(["hello"])
"""
result = Tensor._from_object(bytes_array)
if result.data_type != DataType.BYTES:
raise ValueError(
f"Unsupported conversion from {bytes_array} to BYTES Tensor. Got {result.data_type}"
)
return result
@staticmethod
def _from_object(obj: list[Any] | numpy.ndarray | Any) -> Tensor:
"""Create a tensor from an object.
Creates a tensor from an object using specific conversion
methods if available, or falls back to from_dlpack for objects
supporting the __dlpack__ protocol.
Specific conversions are currently supported for:
list[obj: Any] : implicitly converted to numpy.array()
numpy.ndarray : serialized if required to BYTES tensor
Parameters
----------
obj : list[Any] | numpy.ndarray | Any
The input object to create the tensor from.
Returns
-------
Tensor
A new tensor created from the specified object.
Examples
--------
tensor = Tensor._from_object(numpy.array(["hello"]))
tensor = Tensor._from_object(["hello"])
"""
if type(obj) in Tensor._from_converters:
return Tensor._from_converters[type(obj)](obj)
elif hasattr(obj, "__dlpack__"):
return Tensor.from_dlpack(obj)
else:
raise ValueError(
f"Input type {type(obj)} not supported. Must be one of {list(Tensor._from_converters.keys())} or the type must support __dlpack__"
)
@staticmethod
def from_dlpack(obj: Any) -> Tensor:
"""Create a tensor from a DLPack-compatible object.
Parameters
----------
obj : Any
The DLPack-compatible object.
Returns
-------
Tensor
A new tensor created from the DLPack-compatible object.
Examples
--------
tensor = Tensor.from_dlpack(numpy.array([0,1,2], dtype=numpy.float16))
tensor = Tensor.from_dlpack(torch.zeros(100, dtype=torch.float16))
"""
dlpack_object = DLPackObject(obj)
data_type = dlpack_object.data_type
shape = dlpack_object.shape
memory_buffer = MemoryBuffer._from_dlpack_object(
obj, dlpack_object=dlpack_object
)
return Tensor(data_type, shape, memory_buffer)
def to_host(self) -> Tensor:
"""Move the tensor to CPU memory from device memory
Returns
-------
Tensor
The tensor moved to the CPU.
Examples
--------
tensor = Tensor.from_dlpack(torch.zeros(100, dtype=torch.float16).to("cuda"))
numpy_nd_array = numpy.array(tensor.to_host())
"""
return self.to_device("cpu")
def to_device(self, device: DeviceOrMemoryType) -> Tensor:
"""Move the tensor to the specified device.
Parameters
----------
device : DeviceOrMemoryType
The target device. Device can be specified as a string,
MemoryType, tuple[MemoryType, memory_type_id], or
tuple[DLDeviceType, device_id].
Returns
-------
Tensor
The tensor moved to the specified device.
Examples
--------
tensor_cpu = Tensor.from_dlpack(numpy.array([0,1,2], dtype=numpy.float16))
# Different ways to specify the device
tensor_gpu = tensor_cpu.to_device(MemoryType.GPU)
tensor_gpu = tensor_cpu.to_device((MemoryType.GPU,0))
tensor_gpu = tensor_cpu.to_device((DLDeviceType.kDLCUDA,0))
tensor_gpu = tensor_cpu.to_device("gpu")
tensor_gpu = tensor_cpu.to_device("gpu:0")
ndarray_gpu = cupy.from_dlpack(tensor_gpu)
ndarray_gpu[0] = ndarray_gpu.mean()
tensor_cpu = tensor_gpu.to_device("cpu")
ndarray_cpu = numpy.from_dlpack(tensor_cpu)
assert ndarray_cpu[0] == ndarray_gpu[0]
"""
memory_type, memory_type_id = parse_device_or_memory_type(device)
if self.memory_type == memory_type and self.memory_type_id == memory_type_id:
return self
if self.memory_type == MemoryType.CPU_PINNED and memory_type == MemoryType.CPU:
return self
if cupy is not None:
if self.memory_type in (MemoryType.CPU, MemoryType.CPU_PINNED):
ndarray = numpy.from_dlpack(self)
else:
ndarray = cupy.from_dlpack(self)
if memory_type == MemoryType.CPU:
return Tensor.from_dlpack(cupy.asnumpy(ndarray))
if memory_type == MemoryType.GPU:
with cupy.cuda.Device(memory_type_id):
return Tensor.from_dlpack(cupy.asarray(ndarray))
raise ValueError(
f"Conversion from {(self.memory_type,self.memory_type_id)} to {(memory_type, memory_type_id)} not supported."
)
def _to_numpy_on_host(self) -> numpy.ndarray:
if self.memory_type in (MemoryType.CPU, MemoryType.CPU_PINNED):
return numpy.from_dlpack(self)
if cupy is not None:
return cupy.asnumpy(cupy.from_dlpack(self))
raise ValueError(
f"Conversion from {self.memory_type} to numpy array not supported."
)
@staticmethod
def _deserialize_bytes_array(numpy_ndarray: numpy.ndarray) -> numpy.ndarray:
result = []
_buffer = memoryview(numpy_ndarray)
offset = 0
while offset < len(_buffer):
(item_length,) = struct.unpack_from("@I", _buffer, offset)
offset += 4
result.append(bytes(_buffer[offset : offset + item_length]))
offset += item_length
return numpy.array(result, dtype=numpy.object_)
@staticmethod
def _serialize_numpy_bytes_array(array: numpy.ndarray) -> numpy.ndarray:
result = []
for array_item in numpy.nditer(array, flags=["refs_ok"], order="C"):
item = array_item.item() # type: ignore
if not isinstance(item, bytes):
item = str(item).encode("utf-8")
result.append(struct.pack("@I", len(item)))
result.append(item)
return numpy.frombuffer(b"".join(result), dtype=numpy.byte)
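# The BYTES wire format handled by the two helpers above is a flat byte stream
# of (native-endian uint32 length, raw element bytes) pairs. For example, on a
# little-endian platform:
#   [b"ab", b"c"] -> b"\x02\x00\x00\x00" + b"ab" + b"\x01\x00\x00\x00" + b"c"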
@staticmethod
def _from_list(obj: list[Any]) -> Tensor:
try:
return Tensor._from_numpy(numpy.array(obj))
except Exception as e:
raise ValueError(f"Conversion from {obj} to tensor not supported.") from e
@staticmethod
def _from_numpy(obj: numpy.ndarray | numpy.generic) -> Tensor:
data_type = NUMPY_TO_DATA_TYPE[obj.dtype.type]
shape = obj.shape
if isinstance(obj, numpy.generic):
obj = numpy.asarray(obj)
if data_type == DataType.BYTES:
obj = Tensor._serialize_numpy_bytes_array(obj)
memory_buffer = MemoryBuffer(
data_ptr=obj.ctypes.data,
memory_type=MemoryType.CPU,
memory_type_id=0,
size=obj.itemsize * obj.size,
owner=obj,
)
return Tensor(data_type, shape, memory_buffer)
def _create_managed_tensor(self) -> DLManagedTensor:
# Allocates space for a managed tensor object
# and fills in the fields
#
# To ensure the lifetime of the managed tensor we create a
# context object that includes a newly created shape array and a
# reference to self
size = ctypes.c_size_t(ctypes.sizeof(DLManagedTensor))
address = ctypes.pythonapi.PyMem_RawMalloc(size)
dl_managed_tensor = DLManagedTensor.from_address(address)
dl_managed_tensor.dl_tensor.data = self.data_ptr
dl_managed_tensor.dl_tensor.device = DLDevice(
MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE[self.memory_type],
self.memory_type_id,
)
dl_managed_tensor.dl_tensor.dtype = DATA_TYPE_TO_DLPACK_DTYPE[self.data_type]
dl_managed_tensor.dl_tensor.ndim = len(self.shape)
manager_ctx = _ManagerCtx(self)
dl_managed_tensor.dl_tensor.shape = manager_ctx.shape
dl_managed_tensor.dl_tensor.strides = manager_ctx.strides
dl_managed_tensor.dl_tensor.byte_offset = 0
dl_managed_tensor.deleter = Tensor._managed_tensor_deleter
dl_managed_tensor.manager_ctx = manager_ctx.reference()
return dl_managed_tensor
@staticmethod
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _managed_tensor_deleter(handle: int) -> None:
dl_managed_tensor = DLManagedTensor.from_address(handle)
_ManagerCtx.release(dl_managed_tensor.manager_ctx)
ctypes.pythonapi.PyMem_RawFree(handle)
@staticmethod
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _pycapsule_deleter(handle: ctypes.c_void_p) -> None:
try:
pycapsule: ctypes.py_object = ctypes.cast(handle, ctypes.py_object)
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, c_str_dltensor):
dl_managed_tensor = ctypes.pythonapi.PyCapsule_GetPointer(
pycapsule, c_str_dltensor
)
Tensor._managed_tensor_deleter(dl_managed_tensor)
ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, None)
except Exception as e:
print(f"Exception occurred while deleting capsule: {e}")
raise e
_from_converters: ClassVar[dict[type, Callable[[Any], Tensor]]] = dict(
{numpy.ndarray: _from_numpy, numpy.generic: _from_numpy, list: _from_list},
)
class _ManagerCtx:
# To ensure the lifetime of the managed tensor we create a
# context object that includes a newly created shape array and a
# reference to self
def __init__(self, tensor: Tensor) -> None:
self._tensor = tensor
self.shape = (ctypes.c_int64 * len(tensor.shape))(*tensor.shape)
self.strides = ctypes.POINTER(ctypes.c_int64)()
def reference(self) -> ctypes.c_void_p:
py_obj = ctypes.py_object(self)
ctypes.pythonapi.Py_IncRef(py_obj)
# Note: Could not find a direct way to cast a python object
# to a c_void_p. The mechanism is to either use id(self) or
# cast as described here:
#
# https://groups.google.com/g/dev-python/c/QRRqVC7gkf4/m/zH7l1gTXBwAJ
#
# To avoid relying on the behavior of id() we use the casting mechanism
return ctypes.POINTER(ctypes.c_void_p)(py_obj)[0] # type: ignore
@staticmethod
def release(reference: ctypes.c_void_p) -> None:
py_obj = ctypes.cast(reference, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(py_obj)
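

def _example_tensor_dlpack_roundtrip() -> None:
    # Illustrative sketch only; this helper name is not part of the module.
    # Tensor implements both sides of the DLPack protocol, so it can exchange
    # data with numpy on the host (and with cupy for GPU memory, if present).
    original = numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32)
    tensor = Tensor.from_dlpack(original)
    assert tensor.data_type == DataType.FP32
    # Consuming the tensor via DLPack yields the same values (typically as a
    # zero-copy view over the original host buffer).
    numpy.testing.assert_array_equal(numpy.from_dlpack(tensor), original)
    # BYTES tensors are length-prefix serialized into newly allocated host
    # memory instead of being shared.
    text = Tensor.from_string_array(["hello", "world"])
    assert text.to_string_array().tolist() == ["hello", "world"]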
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import logging
import threading
import uuid
from enum import IntEnum, auto
from functools import cached_property
from typing import Dict, Optional, Tuple
from urllib.parse import urlsplit
import cupy
import numpy
import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.data_plane import (
DataPlane,
DataPlaneError,
get_icp_data_type,
get_icp_memory_type,
get_icp_memory_type_id,
get_icp_shape,
get_icp_tensor_contents,
get_icp_tensor_size,
get_icp_tensor_uri,
set_icp_data_type,
set_icp_memory_type,
set_icp_memory_type_id,
set_icp_shape,
set_icp_tensor_contents,
set_icp_tensor_size,
set_icp_tensor_uri,
)
from triton_distributed.icp.data_type import DataType
from triton_distributed.icp.memory_buffer import MemoryBuffer
from triton_distributed.icp.memory_type import MemoryType
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.tensor import Tensor
LOGGER = logging.getLogger(__name__)
class UCP_DATA_PLANE_COMMANDS(IntEnum):
GET = auto()
CREATE_REFERENCE = auto()
RELEASE = auto()
# UCP can deadlock when multiple instances are created in a single process,
# so expose a singleton factory instead.
_ucp_data_plane_singleton = None
def UcpDataPlane(
hostname: Optional[str] = None, port: int = 0, keep_endpoints_open: bool = False
):
global _ucp_data_plane_singleton
if _ucp_data_plane_singleton is None:
_ucp_data_plane_singleton = _UcpDataPlane(hostname, port, keep_endpoints_open)
return _ucp_data_plane_singleton
class _UcpDataPlane(DataPlane):
def __init__(
self,
hostname: Optional[str] = None,
port: int = 0,
keep_endpoints_open: bool = False,
) -> None:
self._tensor_store: Dict[uuid.UUID, Tensor] = {}
self._id_size = len(uuid.uuid1().bytes)
self._port = port
self._hostname = hostname or ucp.get_address()
self._event_loop_thread = threading.Thread(
target=self._run_event_loop, daemon=True
)
self._start_event = threading.Event()
self._listener: Optional[ucp.Listener] = None
self._event_loop: Optional[asyncio.AbstractEventLoop] = None
self._closed = False
self._keep_endpoints_open = keep_endpoints_open
self._endpoints: Dict[Tuple[str, int], ucp.Endpoint] = {}
LOGGER.debug(
"Creating UCP data plane with keep_endpoints_open=%s", keep_endpoints_open
)
@cached_property
def _cuda_is_available(self):
# Note: cuda.is_available initializes CUDA and can't be called when
# forking subprocesses; care should be taken to only call it within
# subprocesses or to use the 'spawn' start method.
try:
return cupy.cuda.is_available()
except CUDARuntimeError:
return False
@property
def hostname(self) -> str:
return self._hostname
@property
def port(self) -> int:
return self._port
def connect(self) -> None:
if self._event_loop is None:
self._event_loop_thread.start()
self._start_event.wait()
if self._listener is None or self._listener.closed():
raise DataPlaneError("Unable to start data plane")
async def _close(self, wait_for_release=0):
self._closed = True
if self._listener is not None:
if wait_for_release:
while self._tensor_store and wait_for_release:
await asyncio.sleep(1)
wait_for_release -= 1
self._listener.close()
self._listener = None
def close(self, wait_for_release=0):
if self._event_loop is None:
return
if self._event_loop.is_closed():
return
asyncio.run_coroutine_threadsafe(
self._close(wait_for_release),
self._event_loop,
).result()
def __del__(self):
self.close()
def _run_event_loop(self):
asyncio.run(self._serve())
async def _serve(self):
self._event_loop = asyncio.get_running_loop()
try:
self._listener = ucp.create_listener(self._send_receive, self._port)
self._port = self._listener.port
self._start_event.set()
except Exception:
self._listener = None
self._start_event.set()
while self._listener is not None and not self._listener.closed():
await asyncio.sleep(1)
async def _send_receive(self, ep):
while True:
tensor_id_bytes = numpy.empty(self._id_size, dtype="u1")
await ep.recv(tensor_id_bytes)
tensor_id = uuid.UUID(bytes=tensor_id_bytes.tobytes())
command = numpy.empty(1, dtype="u1")
await ep.recv(command)
if command == UCP_DATA_PLANE_COMMANDS.GET:
if tensor_id in self._tensor_store:
tensor = self._tensor_store[tensor_id]
array_module = numpy
if tensor.memory_type == MemoryType.CPU:
array_module = numpy
device_manager = contextlib.nullcontext()
elif tensor.memory_type == MemoryType.GPU:
array_module = cupy
device_manager = cupy.cuda.Device(
tensor.memory_buffer.memory_type_id
)
else:
raise ValueError(f"Invalid Memory Type {tensor.memory_type}")
with device_manager:
if tensor.data_type == DataType.BYTES:
array = tensor.memory_buffer.owner
else:
array = array_module.from_dlpack(tensor)
await ep.send(array)
elif command == UCP_DATA_PLANE_COMMANDS.CREATE_REFERENCE:
if tensor_id in self._tensor_store:
reference_tensor_id = uuid.uuid1()
self._tensor_store[reference_tensor_id] = self._tensor_store[
tensor_id
]
await ep.send(numpy.array(reference_tensor_id.bytes))
elif command == UCP_DATA_PLANE_COMMANDS.RELEASE:
if tensor_id in self._tensor_store:
del self._tensor_store[tensor_id]
await ep.send(numpy.array(tensor_id.bytes))
if not self._keep_endpoints_open:
break
await ep.close()
def _put_tensor(
self,
result: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
tensor: Tensor,
tensor_id: Optional[uuid.UUID] = None,
use_tensor_contents: bool = False,
):
if self._closed:
raise DataPlaneError("Adding tensor after close")
set_icp_data_type(result, tensor.data_type)
set_icp_shape(result, tensor.shape)
if use_tensor_contents:
set_icp_tensor_contents(result, tensor)
else:
if tensor_id is None:
tensor_id = uuid.uuid1()
self._tensor_store[tensor_id] = tensor
tensor_uri = f"ucp://{self._hostname}:{self._port}/{tensor_id}"
set_icp_tensor_uri(result, tensor_uri)
set_icp_memory_type(result, tensor.memory_buffer.memory_type)
set_icp_memory_type_id(result, tensor.memory_buffer.memory_type_id)
set_icp_tensor_size(result, tensor.size)
def put_input_tensor(
self,
tensor: Tensor,
tensor_id: Optional[uuid.UUID] = None,
use_tensor_contents: bool = False,
) -> ModelInferRequest.InferInputTensor:
"""Put an input tensor into the data plane or within
returned ModelInferRequest.InferInputTensor itself.
Args:
tensor: The tensor to put.
tensor_id: The id of the tensor to put.
If not provided, a new id will be generated.
use_tensor_contents: when True, tensor data will be
added directly to ModelInferRequest.InferInputTensor
contents field; otherwise tensor data will be sent
separately on the data plane.
Returns:
ModelInferRequest.InferInputTensor object.
"""
result = ModelInferRequest.InferInputTensor()
self._put_tensor(
result,
tensor,
tensor_id=tensor_id,
use_tensor_contents=use_tensor_contents,
)
return result
def put_output_tensor(
self,
tensor: Tensor,
tensor_id: Optional[uuid.UUID] = None,
use_tensor_contents: bool = False,
) -> ModelInferResponse.InferOutputTensor:
"""Put an output tensor into the data plane or within
returned ModelInferResponse.InferOutputTensor itself.
Args:
tensor: The tensor to put.
tensor_id: The id of the tensor to put.
If not provided, a new id will be generated.
use_tensor_contents: when True, tensor data will be
added directly to ModelInferResponse.InferOutputTensor
contents field; otherwise tensor data will be sent
separately on the data plane.
Returns:
ModelInferResponse.InferOutputTensor object.
"""
result = ModelInferResponse.InferOutputTensor()
self._put_tensor(
result,
tensor,
tensor_id=tensor_id,
use_tensor_contents=use_tensor_contents,
)
return result
def _split_tensor_uri(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> tuple[uuid.UUID, str, int]:
tensor_uri = get_icp_tensor_uri(remote_tensor)
split_uri = urlsplit(tensor_uri)
path = str(split_uri.path).replace("/", "")
tensor_id = uuid.UUID(path)
host = split_uri.hostname
port = split_uri.port
if host is None or not isinstance(host, str):
raise DataPlaneError(f"Invalid host {host}")
if port is None:
raise DataPlaneError(f"Invalid Port {port}")
return tensor_id, host, port
async def _get_remote_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
requested_memory_type: Optional[MemoryType],
requested_memory_type_id: Optional[int],
) -> Tensor:
tensor_contents = get_icp_tensor_contents(remote_tensor)
if tensor_contents is not None:
return tensor_contents
tensor_size = get_icp_tensor_size(remote_tensor)
memory_type = get_icp_memory_type(remote_tensor)
data_type = get_icp_data_type(remote_tensor)
shape = get_icp_shape(remote_tensor)
tensor_id, host, port = self._split_tensor_uri(remote_tensor)
storage = None
if tensor_size is None:
raise DataPlaneError("tensor size can not be none")
if requested_memory_type is not None:
memory_type = requested_memory_type
if memory_type == MemoryType.GPU and self._cuda_is_available:
array_module = cupy
if requested_memory_type_id is not None:
device_manager = cupy.cuda.Device(requested_memory_type_id)
else:
device_manager = contextlib.nullcontext()
else:
array_module = numpy
device_manager = contextlib.nullcontext()
with device_manager:
storage = array_module.empty(tensor_size, dtype="u1")
try:
endpoint = await self._create_endpoint(host, port)
await endpoint.send(numpy.array(tensor_id.bytes))
await endpoint.send(
numpy.array(UCP_DATA_PLANE_COMMANDS.GET, dtype="u1")
)
await endpoint.recv(storage)
if not self._keep_endpoints_open:
await self._close_endpoint(host, port)
return Tensor(data_type, shape, MemoryBuffer.from_dlpack(storage))
except Exception as e:
raise DataPlaneError(f"Error Getting Tensor:\n{remote_tensor}") from e
async def _create_remote_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
result: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
):
tensor_size = get_icp_tensor_size(remote_tensor)
memory_type = get_icp_memory_type(remote_tensor)
memory_type_id = get_icp_memory_type_id(remote_tensor)
if tensor_size is None or memory_type is None or memory_type_id is None:
raise DataPlaneError("tensor size and memory type must not be none")
set_icp_shape(result, get_icp_shape(remote_tensor))
set_icp_data_type(result, get_icp_data_type(remote_tensor))
set_icp_tensor_size(result, tensor_size)
set_icp_memory_type(result, memory_type)
set_icp_memory_type_id(result, memory_type_id)
if remote_tensor.HasField("contents"):
for value in remote_tensor.contents.bytes_contents:
result.contents.bytes_contents.append(value)
return
tensor_id, host, port = self._split_tensor_uri(remote_tensor)
try:
endpoint = await self._create_endpoint(host, port)
await endpoint.send(numpy.array(tensor_id.bytes))
await endpoint.send(
numpy.array(UCP_DATA_PLANE_COMMANDS.CREATE_REFERENCE, dtype="u1")
)
reference_tensor_id_bytes = numpy.empty(self._id_size, dtype="u1")
await endpoint.recv(reference_tensor_id_bytes)
if not self._keep_endpoints_open:
await self._close_endpoint(host, port)
reference_tensor_id = uuid.UUID(bytes=reference_tensor_id_bytes.tobytes())
set_icp_tensor_uri(result, f"ucp://{host}:{port}/{reference_tensor_id}")
except Exception as e:
raise DataPlaneError("Error Referencing Tensor:\n{remote_tensor}") from e
async def _release_remote_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
):
tensor_id, host, port = self._split_tensor_uri(remote_tensor)
try:
endpoint = await self._create_endpoint(host, port)
await endpoint.send(numpy.array(tensor_id.bytes))
await endpoint.send(
numpy.array(UCP_DATA_PLANE_COMMANDS.RELEASE, dtype="u1")
)
ack_tensor_id = numpy.empty(self._id_size, dtype="u1")
await endpoint.recv(ack_tensor_id)
if not self._keep_endpoints_open:
await self._close_endpoint(host, port)
except Exception as e:
raise DataPlaneError(f"Error Releasing Tensor:\n{remote_tensor}") from e
def get_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
requested_memory_type: Optional[MemoryType] = None,
requested_memory_type_id: Optional[int] = None,
) -> Tensor:
if self._event_loop is None:
raise DataPlaneError("Not connected")
return asyncio.run_coroutine_threadsafe(
self._get_remote_tensor(
remote_tensor, requested_memory_type, requested_memory_type_id
),
self._event_loop,
).result()
def create_input_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> ModelInferRequest.InferInputTensor:
if self._event_loop is None:
raise DataPlaneError("Not connected")
result = ModelInferRequest.InferInputTensor()
asyncio.run_coroutine_threadsafe(
self._create_remote_tensor_reference(remote_tensor, result),
self._event_loop,
).result()
return result
def create_output_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> ModelInferResponse.InferOutputTensor:
if self._event_loop is None:
raise DataPlaneError("Not connected")
result = ModelInferResponse.InferOutputTensor()
asyncio.run_coroutine_threadsafe(
self._create_remote_tensor_reference(remote_tensor, result),
self._event_loop,
).result()
return result
def release_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> None:
if remote_tensor.HasField("contents"):
return None
if self._event_loop is None:
raise DataPlaneError("Not connected")
return asyncio.run_coroutine_threadsafe(
self._release_remote_tensor(remote_tensor), self._event_loop
).result()
async def _create_endpoint(self, host: str, port: int):
endpoint = self._endpoints.get((host, port))
if endpoint is None:
LOGGER.debug(f"Creating endpoint for {host}:{port}")
endpoint = await ucp.create_endpoint(host, port)
self._endpoints[(host, port)] = endpoint
else:
LOGGER.debug(f"Reusing endpoint for {host}:{port}")
return endpoint
async def _close_endpoint(self, host: str, port: int):
endpoint = self._endpoints.pop((host, port), None)
if endpoint is not None:
LOGGER.debug(f"Closing endpoint for {host}:{port}")
await endpoint.close()
else:
LOGGER.debug(f"Endpoint for {host}:{port} not found")
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from datetime import datetime
import pytest
from triton_distributed.icp.nats_event_plane import (
EventMetadata,
EventTopic,
NatsEventPlane,
)
pytestmark = pytest.mark.pre_merge
class TestEventTopic:
def test_from_string(self):
topic_str = "level1"
event_topic = EventTopic(topic_str)
assert event_topic.event_topic == topic_str
def test_to_string(self):
event_topic = EventTopic(["level1", "level2"])
assert str(event_topic) == "level1.level2"
class TestEvent:
@pytest.fixture
def sample_event_metadata(self):
event_topic = EventTopic("test.event_topic")
return EventMetadata(
event_id=uuid.uuid4(),
event_topic=event_topic,
event_type="test_event",
timestamp=datetime.utcnow(),
component_id=uuid.uuid4(),
)
class TestEventPlaneNats:
@pytest.fixture
def event_plane_instance(self):
server_url = "tls://localhost:4222"
component_id = uuid.uuid4()
return NatsEventPlane(server_url, component_id)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import uuid
from typing import List
import pytest
from utils import event_plane, nats_server
from triton_distributed.icp import Event, EventTopic, NatsEventPlane
pytestmark = pytest.mark.pre_merge
@pytest.mark.asyncio
class TestEventPlaneFunctional:
@pytest.mark.asyncio
async def test_single_publisher_subscriber(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
async def callback(event):
received_events.append(event)
print(event)
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"test_payload"
await event_plane.subscribe(
callback, event_topic=event_topic, event_type=event_type
)
event_metadata = await event_plane.publish(event, event_type, event_topic)
# Allow time for message to propagate
await asyncio.sleep(2)
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
@pytest.mark.asyncio
async def test_single_publisher_subscriber_iterator(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"test_payload"
subscription = await event_plane.subscribe(
event_topic=event_topic, event_type=event_type
)
event_metadata = await event_plane.publish(
event, event_topic=event_topic, event_type=event_type
)
# Allow time for message to propagate
await asyncio.sleep(2)
async for x in subscription:
print(x.timestamp)
print(x.event_id)
print(x.event_type)
print(x.event_topic)
print(x.payload)
received_events.append(x)
break
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
@pytest.mark.asyncio
async def test_default_subscription(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
event = b"test_payload"
subscription = await event_plane.subscribe()
event_metadata = await event_plane.publish(
event,
)
# Allow time for message to propagate
await asyncio.sleep(2)
async for x in subscription:
print(x.timestamp)
print(x.event_id)
print(x.event_type)
print(x.event_topic)
print(x.payload)
received_events.append(x)
break
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
@pytest.mark.asyncio
async def test_event_topic_list(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
event = b"test_payload"
subscription = await event_plane.subscribe(event_topic="hello")
event_metadata = await event_plane.publish(event, event_topic=["hello"])
# Allow time for message to propagate
await asyncio.sleep(2)
async for x in subscription:
print(x.timestamp)
print(x.event_id)
print(x.event_type)
print(x.event_topic)
print(x.payload)
received_events.append(x)
break
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
@pytest.mark.asyncio
async def test_custom_type(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
@dataclasses.dataclass
class MyEvent:
test: str
index: int
event = MyEvent("hello", 0)
subscription = await event_plane.subscribe()
event_metadata = await event_plane.publish(
event,
)
# Allow time for message to propagate
await asyncio.sleep(2)
async for x in subscription:
print(x.timestamp)
print(x.event_id)
print(x.event_type)
print(x.event_topic)
print(x.payload)
print(x.typed_payload(MyEvent))
received_events.append(x)
break
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
assert isinstance(received_events[0].typed_payload(MyEvent), type(event))
assert isinstance(received_events[0].typed_payload(dict), dict)
@pytest.mark.asyncio
async def test_one_publisher_multiple_subscribers(self, nats_server):
results_1: List[Event] = []
results_2: List[Event] = []
results_3: List[Event] = []
async def callback_1(event):
results_1.append(event)
async def callback_2(event):
results_2.append(event)
async def callback_3(event):
results_3.append(event)
event_topic = EventTopic(["test"])
event_type = "multi_event"
event = b"multi_payload"
# async with event_plane_context() as event_plane1:
server_url = "tls://localhost:4222"
component_id = uuid.uuid4()
event_plane2 = NatsEventPlane(server_url, component_id)
try:
await event_plane2.connect()
try:
subscription1 = await event_plane2.subscribe(
callback_1, event_topic=event_topic
)
try:
subscription2 = await event_plane2.subscribe(
callback_2, event_topic=event_topic
)
try:
subscription3 = await event_plane2.subscribe(
callback_3, event_type=event_type
)
component_id = uuid.uuid4()
event_plane1 = NatsEventPlane(server_url, component_id)
try:
await event_plane1.connect()
ch1 = EventTopic(["test", "1"])
ch2 = EventTopic(["test", "2"])
await event_plane1.publish(event, event_type, ch1)
await event_plane1.publish(event, event_type, ch2)
# Allow time for message propagation
await asyncio.sleep(2)
assert len(results_1) == 2
assert len(results_2) == 2
assert len(results_3) == 2
finally:
await event_plane1.disconnect()
finally:
await subscription3.unsubscribe()
finally:
await subscription2.unsubscribe()
finally:
await subscription1.unsubscribe()
finally:
await event_plane2.disconnect()
@pytest.mark.asyncio
async def test_context_manager(self, nats_server):
"""Test that context managers properly handle connection/disconnection and subscription/unsubscription."""
received_events: List[Event] = []
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"test_payload"
# Test successful operation with context managers
async with NatsEventPlane() as plane:
assert plane.is_connected()
async def callback(event):
received_events.append(event)
async with await plane.subscribe(
callback, event_topic=event_topic, event_type=event_type
) as subscription:
assert subscription._nc_sub is not None
event_metadata = await plane.publish(event, event_type, event_topic)
await asyncio.sleep(2) # Allow time for message to propagate
# After subscription context, should be unsubscribed
assert subscription._nc_sub is None
# After plane context, should be disconnected
assert not plane.is_connected()
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
# Test error handling in context managers
with pytest.raises(RuntimeError):
async with NatsEventPlane() as plane:
async with await plane.subscribe(
callback, event_topic=event_topic, event_type=event_type
):
raise RuntimeError("Test error")
# Should not reach here
pytest.fail("Should have raised exception")
# Should not reach here
pytest.fail("Should have raised exception")
# Even after error, resources should be cleaned up
assert not plane.is_connected()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import subprocess
import time
from contextlib import asynccontextmanager
import pytest_asyncio
from triton_distributed.icp import (
DEFAULT_EVENTS_HOST,
DEFAULT_EVENTS_PORT,
NatsEventPlane,
)
logger = logging.getLogger(__name__)
def is_port_in_use(port: int) -> bool:
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
@pytest_asyncio.fixture(loop_scope="session")
async def nats_server():
"""Fixture to start and stop a NATS server."""
process = None
try:
# Raise more intuitive error to developer if port is already in-use.
if is_port_in_use(DEFAULT_EVENTS_PORT):
raise RuntimeError(
f"ERROR: NATS Port {DEFAULT_EVENTS_PORT} already in use. Is a nats-server already running?"
)
# Start NATS server
logger.info("NATS server starting")
process = subprocess.Popen(
[
"nats-server",
"-p",
str(DEFAULT_EVENTS_PORT),
"-addr",
DEFAULT_EVENTS_HOST,
],
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
while not is_port_in_use(DEFAULT_EVENTS_PORT):
logger.debug("Waiting for NATS server to start...")
time.sleep(0.2)
logger.info("NATS server started")
yield process
finally:
# Stop the NATS server
if process:
logger.debug("Closing NATS server")
process.terminate()
# communicate() ensures we consume all stdout/stderr so they can close
out, err = process.communicate()
# Log the captured output for debugging:
logger.debug("NATS server stdout: %s", out.decode())
logger.debug("NATS server stderr: %s", err.decode())
if process.stdout:
process.stdout.close()
if process.stderr:
process.stderr.close()
# Wait for the NATS server process to exit
process.wait()
@asynccontextmanager
async def event_plane_context():
# with nats_server_context() as server:
print(f"Print loop plane context: {id(asyncio.get_running_loop())}")
plane = NatsEventPlane()
await plane.connect()
yield plane
await plane.disconnect()
@pytest_asyncio.fixture(loop_scope="function")
async def event_plane():
print(f"Print loop plane: {id(asyncio.get_running_loop())}")
plane = NatsEventPlane()
await plane.connect()
yield plane
await plane.disconnect()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import uuid
from multiprocessing import Process, Queue
from typing import Sequence
import cupy
import numpy
import pytest
import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.data_plane import DataPlaneError
from triton_distributed.icp.data_type import DATA_TYPE_TO_NUMPY_DTYPE, DataType
from triton_distributed.icp.memory_type import MemoryType
from triton_distributed.icp.tensor import Tensor
from triton_distributed.icp.ucp_data_plane import (
UcpDataPlane,
get_icp_tensor_uri,
set_icp_tensor_uri,
)
# TODO decide if some tests should be removed
# from pre_merge
pytestmark = pytest.mark.pre_merge
def _cuda_available():
# Note: cuda.is_available initializes CUDA
# and can't be called when forking subprocesses
try:
return cupy.cuda.is_available()
except CUDARuntimeError:
return False
def data_plane_reader(
input_tensor_queue: Queue,
tensor_descriptor_queue: Queue,
output_tensor_queue: Queue,
memory_type: MemoryType,
memory_type_id: int,
):
ucp.reset()
data_plane = UcpDataPlane()
data_plane.connect()
output_tensor = None
get_error = None
release_error = None
while True:
input_tensor = tensor_descriptor_queue.get()
if input_tensor is None:
break
try:
output_tensor = data_plane.get_tensor(
input_tensor,
requested_memory_type=memory_type,
requested_memory_type_id=memory_type_id,
)
if output_tensor.data_type == DataType.BYTES:
output_tensor = output_tensor.to_bytes_array()
else:
if memory_type == MemoryType.GPU and _cuda_available():
output_tensor = cupy.from_dlpack(output_tensor)
else:
output_tensor = numpy.from_dlpack(output_tensor)
except DataPlaneError as e:
get_error = e
try:
data_plane.release_tensor(input_tensor)
except DataPlaneError as e:
release_error = e
if get_error:
output_tensor_queue.put((get_error, release_error))
else:
output_tensor_queue.put((output_tensor, release_error))
output_tensor_queue.put((None, None))
data_plane.close()
def data_plane_writer(
input_tensor_queue: Queue,
tensor_descriptor_queue: Queue,
output_tensor_queue: Queue,
memory_type: MemoryType,
memory_type_id: int,
use_invalid_descriptor: bool = False,
timeout=30,
use_tensor_contents: bool = False,
):
ucp.reset()
data_plane = UcpDataPlane()
data_plane.connect()
while True:
input_tensor = input_tensor_queue.get()
if input_tensor is None:
tensor_descriptor_queue.put(None)
break
input_tensor = Tensor._from_object(input_tensor)
input_tensor_descriptor = data_plane.put_input_tensor(
input_tensor, use_tensor_contents=use_tensor_contents
)
if use_invalid_descriptor and not use_tensor_contents:
tensor_uri = get_icp_tensor_uri(input_tensor_descriptor)
invalid_tensor_id = str(uuid.uuid1())
tensor_uri = tensor_uri[: -len(invalid_tensor_id)]
tensor_uri = tensor_uri + invalid_tensor_id
set_icp_tensor_uri(input_tensor_descriptor, tensor_uri)
tensor_descriptor_queue.put(input_tensor_descriptor)
if not use_invalid_descriptor and timeout:
data_plane.close(wait_for_release=timeout)
data_plane.close()
@pytest.fixture
def tensors():
tensors = [
numpy.random.randint(0, 10, size=(2, 3)),
numpy.random.randint(0, 10, size=(100)),
numpy.random.randint(2, size=(1), dtype=bool),
]
return tensors
@pytest.mark.timeout(60, method="thread")
def test_data_plane_error_invalid_tensor_uri(request):
input_tensor_queue: Queue = Queue()
tensor_descriptor_queue: Queue = Queue()
output_tensor_queue: Queue = Queue()
input_tensors = []
memory_type = MemoryType.CPU
memory_type_id = 0
reader = Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
True,
30,
),
)
reader.start()
writer.start()
tensors = request.getfixturevalue("tensors")
for tensor in tensors:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.memory_type == MemoryType.CPU or not _cuda_available():
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(cupy.from_dlpack(input_tensor))
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert isinstance(output_tensor, DataPlaneError)
assert isinstance(release_error, DataPlaneError)
@pytest.mark.timeout(30, method="thread")
@pytest.mark.parametrize(
"memory_type,memory_type_id", [(MemoryType.CPU, 0), (MemoryType.GPU, 0)]
)
def test_requested_memory_type(memory_type, memory_type_id, request):
ctx = multiprocessing.get_context("spawn")
input_tensor_queue = ctx.Queue()
tensor_descriptor_queue = ctx.Queue()
output_tensor_queue = ctx.Queue()
input_tensors = []
output_tensors = []
reader = ctx.Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = ctx.Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
False,
30,
),
)
reader.start()
writer.start()
tensors = request.getfixturevalue("tensors")
for tensor in tensors:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.memory_type == MemoryType.CPU or not _cuda_available():
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(cupy.from_dlpack(input_tensor))
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert not isinstance(output_tensor, DataPlaneError)
output_tensors.append(Tensor.from_dlpack(output_tensor))
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
expected_memory_type = memory_type
if not _cuda_available():
expected_memory_type = MemoryType.CPU
assert output_tensor.memory_type == expected_memory_type
assert output_tensor.memory_type_id == memory_type_id
input_comparison = numpy.from_dlpack(input_tensor.to_host())
output_comparison = numpy.from_dlpack(output_tensor.to_host())
numpy.testing.assert_equal(input_comparison, output_comparison)
print(input_tensor, output_tensor)
def _get_random_tensor(data_type: DataType, size: Sequence[int]):
dtype = DATA_TYPE_TO_NUMPY_DTYPE[data_type]
value = numpy.random.rand(*size)
return value.astype(dtype)
@pytest.mark.timeout(30, method="thread")
@pytest.mark.parametrize(
"data_type",
[
data_type
for data_type in DataType.__members__.values()
if data_type not in [DataType.INVALID, DataType.BF16]
],
ids=[
data_type
for data_type in DataType.__members__.keys()
if data_type not in ["INVALID", "BF16"]
],
)
def test_tensor_types(request, data_type):
input_tensor_queue: Queue = Queue()
tensor_descriptor_queue: Queue = Queue()
output_tensor_queue: Queue = Queue()
input_tensors = []
output_tensors = []
memory_type = MemoryType.CPU
memory_type_id = 0
reader = Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
False,
30,
),
)
reader.start()
writer.start()
tensors = []
tensors.append(_get_random_tensor(data_type, [1, 4]))
for tensor in tensors:
if data_type == DataType.BYTES:
input_tensors.append(Tensor._from_object(tensor))
else:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.data_type != DataType.BYTES:
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(input_tensor.to_bytes_array())
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert not isinstance(output_tensor, DataPlaneError)
output_tensors.append(Tensor._from_object(output_tensor))
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
expected_memory_type = memory_type
if not _cuda_available():
expected_memory_type = MemoryType.CPU
assert output_tensor.memory_type == expected_memory_type
assert output_tensor.memory_type_id == memory_type_id
if input_tensor.data_type == DataType.BYTES:
input_comparison = input_tensor.to_bytes_array()
output_comparison = output_tensor.to_bytes_array()
else:
input_comparison = numpy.from_dlpack(input_tensor.to_host())
output_comparison = numpy.from_dlpack(output_tensor.to_host())
numpy.testing.assert_equal(input_comparison, output_comparison)
print(input_tensor, output_tensor)
@pytest.mark.timeout(30, method="thread")
@pytest.mark.parametrize(
"data_type",
[
data_type
for data_type in DataType.__members__.values()
if data_type not in [DataType.INVALID, DataType.BF16]
],
ids=[
data_type
for data_type in DataType.__members__.keys()
if data_type not in ["INVALID", "BF16"]
],
)
def test_use_tensor_contents(request, data_type):
input_tensor_queue: Queue = Queue()
tensor_descriptor_queue: Queue = Queue()
output_tensor_queue: Queue = Queue()
input_tensors = []
output_tensors = []
memory_type = MemoryType.CPU
memory_type_id = 0
reader = Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
True,
30,
True,
),
)
reader.start()
writer.start()
tensors = []
tensors.append(_get_random_tensor(data_type, [2, 4]))
for tensor in tensors:
if data_type == DataType.BYTES:
input_tensors.append(Tensor._from_object(tensor))
else:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.data_type != DataType.BYTES:
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(input_tensor.to_bytes_array())
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert not isinstance(output_tensor, DataPlaneError)
output_tensors.append(Tensor._from_object(output_tensor))
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
expected_memory_type = memory_type
if not _cuda_available():
expected_memory_type = MemoryType.CPU
assert output_tensor.memory_type == expected_memory_type
assert output_tensor.memory_type_id == memory_type_id
if input_tensor.data_type == DataType.BYTES:
input_comparison = input_tensor.to_bytes_array()
output_comparison = output_tensor.to_bytes_array()
else:
input_comparison = numpy.from_dlpack(input_tensor.to_host())
output_comparison = numpy.from_dlpack(output_tensor.to_host())
numpy.testing.assert_equal(input_comparison, output_comparison)
print(input_tensor, output_tensor)
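# --- Illustrative sketch (not part of the original test file) ---
# The multi-process tests above exercise the UCP data plane through queues.
# The helper below condenses the same tensor lifecycle into one process:
# put_input_tensor() produces a descriptor, get_tensor() materializes it, and
# release_tensor() frees the remote buffer. It is a minimal sketch that reuses
# this module's imports and assumes a writer and a reader UcpDataPlane can
# coexist in a single process; it is not invoked by any test.
def _data_plane_round_trip_sketch():
    writer = UcpDataPlane()
    reader = UcpDataPlane()
    writer.connect()
    reader.connect()
    tensor = Tensor.from_dlpack(numpy.random.randint(0, 10, size=(2, 3)))
    descriptor = writer.put_input_tensor(tensor, use_tensor_contents=False)
    fetched = reader.get_tensor(
        descriptor,
        requested_memory_type=MemoryType.CPU,
        requested_memory_type_id=0,
    )
    numpy.testing.assert_equal(
        numpy.from_dlpack(tensor.to_host()), numpy.from_dlpack(fetched.to_host())
    )
    reader.release_tensor(descriptor)
    reader.close()
    writer.close()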
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import shutil
import subprocess
import time
import uuid
from multiprocessing import Process, Queue
import pytest
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.request_plane import (
get_icp_component_id,
set_icp_final_response,
)
NATS_PORT = 4222
# TODO decide if some tests should be removed
# from pre_merge
pytestmark = pytest.mark.pre_merge
def is_port_in_use(port: int) -> bool:
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
@pytest.fixture
def nats_server(request):
command = [
"/usr/local/bin/nats-server",
"--jetstream",
"--debug",
"--trace",
"--port",
str(NATS_PORT),
]
print(f"Running: [{' '.join(command)}]")
# Raise more intuitive error to developer if port is already in-use.
if is_port_in_use(NATS_PORT):
raise RuntimeError(
f"ERROR: NATS Port {NATS_PORT} already in use. Is a nats-server already running?"
)
shutil.rmtree("/tmp/nats", ignore_errors=True)
with open("nats_server.stdout.log", "wt") as output_:
with open("nats_server.stderr.log", "wt") as output_err:
process = subprocess.Popen(
command, stdin=subprocess.DEVNULL, stdout=output_, stderr=output_err
)
time.sleep(1)
yield process
process.terminate()
process.wait()
shutil.rmtree("/tmp/nats", ignore_errors=True)
class ResponseHandler:
def __init__(self, request_plane):
self._request_plane = request_plane
async def response_handler(self, response):
print(response)
request = ModelInferRequest()
request.model_name = response.model_name
request.model_version = response.model_version
print("publishing request")
acks = []
for i in range(5):
acks.append(
self._request_plane.post_request(
request, response_handler=self.response_handler
)
)
        await asyncio.gather(*acks)
@pytest.mark.timeout(30)
async def test_handler(nats_server):
model_name = str(uuid.uuid1())
model_version = "1"
client_request_plane = NatsRequestPlane()
await client_request_plane.connect()
request = ModelInferRequest()
request.model_name = model_name
request.model_version = model_version
await client_request_plane.post_request(
request, response_handler=ResponseHandler(client_request_plane).response_handler
)
worker_request_plane = NatsRequestPlane()
await worker_request_plane.connect()
request_count = 10
while request_count > 0:
requests = await worker_request_plane.pull_requests(
model_name, model_version, 100, 0.1
)
acks = []
async for request in requests:
request_count -= 1
response = ModelInferResponse()
set_icp_final_response(response, True)
acks.append(worker_request_plane.post_response(request, response))
print(request_count)
await asyncio.gather(*acks)
await asyncio.sleep(0.1)
requests = await worker_request_plane.pull_requests(
model_name, model_version, 100, 0.1
)
await worker_request_plane.close()
await client_request_plane.close()
def run_request_generator(request_queue, response_queue, direct_requests=False):
asyncio.run(
request_generator(
request_queue, response_queue, direct_requests=direct_requests
)
)
async def request_generator(request_queue, response_queue, direct_requests=False):
# Generate requests and wait for responses
# if direct_requests == True, then send all requests to the
# worker that responds to the first request
request_plane = NatsRequestPlane()
await request_plane.connect()
target_component_id = None
while True:
request_bytes = request_queue.get()
if request_bytes is None:
response_queue.put(None)
break
request = ModelInferRequest()
request.ParseFromString(request_bytes)
async for response in await request_plane.post_request(
request, response_iterator=True, component_id=target_component_id
):
if direct_requests:
target_component_id = get_icp_component_id(response)
print(response)
response_queue.put(response.SerializeToString())
def run_worker(model_name, model_version, batch_size, request_count, pull_timeout=0.1):
asyncio.run(
worker(model_name, model_version, batch_size, request_count, pull_timeout)
)
async def worker(
model_name, model_version, batch_size, request_count, pull_timeout=0.1
):
request_plane = NatsRequestPlane()
await request_plane.connect()
while request_count:
requests = await request_plane.pull_requests(
model_name, model_version, batch_size, pull_timeout
)
acks = []
async for request in requests:
print(request)
request_count -= 1
response = ModelInferResponse()
set_icp_final_response(response, True)
acks.append(request_plane.post_response(request, responses=response))
await asyncio.gather(*acks)
@pytest.mark.timeout(30)
async def test_iterator(nats_server):
batch_size = 10
request_count = 100
model_name = str(uuid.uuid1())
model_version = "1"
request_queue: Queue = Queue()
response_queue: Queue = Queue()
generator_process = Process(
target=run_request_generator, args=(request_queue, response_queue)
)
worker_process = Process(
target=run_worker, args=(model_name, model_version, batch_size, request_count)
)
generator_process.start()
worker_process.start()
for index in range(request_count):
request_queue.put(
ModelInferRequest(
model_name=model_name, model_version=model_version, id=str(index)
).SerializeToString()
)
request_queue.put(None)
generator_process.join()
worker_process.join()
response_count = 0
while True:
response = response_queue.get()
if response is None:
break
response_count += 1
assert request_count == response_count
@pytest.mark.parametrize("pull_timeout,batch_size", [(0.1, 10), (None, 1)])
@pytest.mark.timeout(30)
async def test_direct_requests(nats_server, pull_timeout, batch_size):
request_count = 100
model_name = str(uuid.uuid1())
model_version = "1"
request_queue: Queue = Queue()
response_queue: Queue = Queue()
# Note with direct_requests == True
# all requests should target a single worker
# and all responses should be from a single worker
generator_process = Process(
target=run_request_generator,
args=(request_queue, response_queue),
kwargs={"direct_requests": True},
)
worker_process_1 = Process(
target=run_worker,
args=(model_name, model_version, batch_size, request_count, pull_timeout),
)
worker_process_2 = Process(
target=run_worker,
args=(model_name, model_version, batch_size, request_count, pull_timeout),
)
worker_process_1.start()
worker_process_2.start()
time.sleep(1)
generator_process.start()
for index in range(request_count):
request_queue.put(
ModelInferRequest(
model_name=model_name, model_version=model_version, id=str(index)
).SerializeToString()
)
request_queue.put(None)
generator_process.join()
worker_process_1.terminate()
worker_process_1.join()
worker_process_2.terminate()
worker_process_2.join()
response_count = 0
responders = set()
while True:
request_bytes = response_queue.get()
if request_bytes is None:
break
response = ModelInferResponse()
response.ParseFromString(request_bytes)
response_count += 1
responders.add(get_icp_component_id(response))
assert len(responders) == 1
assert request_count == response_count
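# --- Illustrative sketch (not part of the original test file) ---
# The tests above split client and worker across processes. This coroutine
# condenses the request plane round trip into one place: the client posts a
# request with response_iterator=True, the worker pulls it, posts a final
# response, and the client consumes it. Minimal sketch only; it reuses this
# module's imports, omits error handling, and is not invoked by any test.
async def _request_plane_round_trip_sketch():
    client = NatsRequestPlane()
    worker = NatsRequestPlane()
    await client.connect()
    await worker.connect()
    request = ModelInferRequest(model_name="sketch_model", model_version="1")
    responses = await client.post_request(request, response_iterator=True)
    async for pulled in await worker.pull_requests("sketch_model", "1", 1, 0.1):
        reply = ModelInferResponse()
        set_icp_final_response(reply, True)
        await worker.post_response(pulled, reply)
    async for response in responses:
        print("response from component", get_icp_component_id(response))
    await worker.close()
    await client.close()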
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[build-system]
requires = ["setuptools>=65.0", "setuptools-scm>=8"]
build-backend = "setuptools.build_meta"
[project]
name = "triton-distributed-runtime"
dynamic = ["version"]
authors = [
{ name = "NVIDIA Inc.", email = "sw-dl-triton@nvidia.com" },
]
license = { text = "Apache-2.0" }
dependencies = ["triton_distributed.icp >= 0"]
[tool.setuptools_scm]
version_file = "src/triton_distributed/runtime/_version.py"
root = "../.."
[tool.setuptools.packages.find]
where = ["src"]
include = ["triton_distributed.runtime*"]
namespaces = true
[tool.setuptools]
license-files = ["../../LICENSE"]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed.runtime.deployment import Deployment as Deployment
from triton_distributed.runtime.logger import get_logger as get_logger
from triton_distributed.runtime.logger import get_logger_config as get_logger_config
from triton_distributed.runtime.operator import Operator as Operator
from triton_distributed.runtime.operator import OperatorConfig as OperatorConfig
from triton_distributed.runtime.remote_operator import RemoteOperator as RemoteOperator
from triton_distributed.runtime.remote_request import (
RemoteInferenceRequest as RemoteInferenceRequest,
)
from triton_distributed.runtime.remote_response import (
RemoteInferenceResponse as RemoteInferenceResponse,
)
try:
from triton_distributed.runtime.triton_core_operator import (
TritonCoreOperator as TritonCoreOperator,
)
except ImportError:
pass
from triton_distributed.runtime.worker import Worker as Worker
from triton_distributed.runtime.worker import WorkerConfig as WorkerConfig
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from pprint import pformat
from typing import Optional, Type
from triton_distributed.icp import (
DataPlane,
NatsRequestPlane,
NatsServer,
RequestPlane,
UcpDataPlane,
)
from triton_distributed.runtime.logger import get_logger
from triton_distributed.runtime.worker import Worker, WorkerConfig
LOGGER_NAME = __name__
class Deployment:
def __init__(
self,
worker_configs: list[WorkerConfig | tuple[WorkerConfig, int]],
log_level=3,
initialize_request_plane=False,
initialize_data_plane=False,
request_plane_args: Optional[tuple[list, dict]] = None,
request_plane: Optional[Type[RequestPlane]] = NatsRequestPlane,
data_plane: Optional[Type[DataPlane]] = UcpDataPlane,
data_plane_args: Optional[tuple[list, dict]] = None,
log_dir="logs",
consolidate_logs=False,
starting_metrics_port=0,
):
self._process_context = multiprocessing.get_context("spawn")
self._worker_configs = worker_configs
self._workers: list[multiprocessing.context.SpawnProcess] = []
self._logger = get_logger(log_level, LOGGER_NAME)
self._default_request_plane = request_plane
self._default_request_plane_args = request_plane_args
self._default_data_plane = data_plane
self._default_data_plane_args = data_plane_args
self._initialize_request_plane = initialize_request_plane
self._initialize_data_plane = initialize_data_plane
        self.request_plane_server: Optional[NatsServer] = None
self._default_log_dir = log_dir
self._default_log_level = log_level
self._consolidate_logs = consolidate_logs
self._starting_metrics_port = starting_metrics_port
@staticmethod
def _start_worker(worker_config):
Worker(worker_config).start()
def start(self):
if self._initialize_request_plane:
if self._default_request_plane == NatsRequestPlane:
self.request_plane_server = NatsServer(log_dir=self._default_log_dir)
else:
                raise ValueError(
                    f"Unknown Request Plane Type, cannot initialize {self._default_request_plane}"
                )
for worker_config in self._worker_configs:
worker_instances = 1
if isinstance(worker_config, tuple):
worker_instances = worker_config[1]
worker_config = worker_config[0]
base_name = worker_config.name
base_port = worker_config.metrics_port
if not base_port and self._starting_metrics_port:
base_port = self._starting_metrics_port
self._starting_metrics_port += worker_instances
request_plane_args, request_plane_kwargs = worker_config.request_plane_args
if not request_plane_args and not request_plane_kwargs:
if self._default_request_plane_args:
worker_config.request_plane_args = self._default_request_plane_args
elif self.request_plane_server:
worker_config.request_plane_args = (
[self.request_plane_server.url],
{},
)
if not worker_config.log_dir:
worker_config.log_dir = self._default_log_dir
if not worker_config.log_level:
worker_config.log_level = self._default_log_level
if self._consolidate_logs:
worker_config.consolidate_logs = True
for index in range(worker_instances):
worker_config.name = f"{base_name}.{index}"
worker_config.metrics_port = base_port + index
self._workers.append(
self._process_context.Process(
target=Deployment._start_worker,
name=worker_config.name,
args=[worker_config],
)
)
self._logger.info(
"\n\nStarting Worker:\n\n\tConfig:\n\t%s\n\t%s\n",
pformat(worker_config),
self._workers[-1],
)
self._workers[-1].start()
def stop(self):
return self.shutdown()
def shutdown(self, join=True, timeout=10):
exit_code = 0
for worker in self._workers:
self._logger.info("\n\nStopping Worker:\n\n\n\t%s\n", worker)
worker.terminate()
if join:
for worker in self._workers:
worker.join(timeout)
for worker in self._workers:
if worker.is_alive():
worker.kill()
worker.join(timeout)
self._logger.info("\n\nWorker Stopped:\n\n\n\t%s\n", worker)
if worker.exitcode is not None:
                    # Note: we accumulate worker exit codes, assuming 0 means
                    # success and anything else indicates an error. This
                    # catches obvious failures but not all of them.
exit_code += worker.exitcode
return exit_code
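# --- Illustrative usage sketch (not part of the original module) ---
# Shows how Deployment might be driven end to end: start a NATS request plane,
# launch two instances of a worker, then shut everything down. The operator
# module path and the WorkerConfig keyword arguments (name=, operators=) are
# hypothetical placeholders rather than a verified API.
if __name__ == "__main__":
    import time

    from triton_distributed.runtime.operator import OperatorConfig

    encoder = OperatorConfig(
        name="encoder", implementation="my_pkg.encoder:EncoderOperator"
    )
    # WorkerConfig fields below are assumed for illustration only.
    worker_config = WorkerConfig(name="encoder_worker", operators=[encoder])
    deployment = Deployment(
        [(worker_config, 2)],
        initialize_request_plane=True,
        log_dir="logs",
    )
    deployment.start()
    try:
        time.sleep(5)
    finally:
        deployment.shutdown()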
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import logging.config
from typing import Any
_LOGGER_NAME = "Triton Distributed Runtime"
_FHANDLER_CONFIG_TEMPLATE = {
"class": "logging.FileHandler",
"formatter": "standard",
}
_LOGGER_CONFIG_TEMPLATE = {"handlers": ["console"], "propagate": True}
_LOGGING_CONFIG_TEMPLATE = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"standard": {
"format": "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
"datefmt": "%H:%M:%S",
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "standard",
"stream": "ext://sys.stdout",
}
},
}
def get_logger_config(log_level=1, logger_name=_LOGGER_NAME, log_file=None):
config_dict: dict[str, Any] = _LOGGING_CONFIG_TEMPLATE
front = "%(asctime)s.%(msecs)03d %(filename)s:%(lineno)s"
config_dict["formatters"]["standard"][
"format"
] = f"{front} [{logger_name}] %(levelname)s: %(message)s"
if log_file:
fh_config_dict = _FHANDLER_CONFIG_TEMPLATE
fh_config_dict["filename"] = str(log_file)
config_dict["handlers"]["file"] = fh_config_dict
logger_config: dict[str, Any] = _LOGGER_CONFIG_TEMPLATE
if log_file:
logger_config["handlers"].append("file")
config_dict["loggers"] = {}
config_dict["loggers"][logger_name] = logger_config
return config_dict
# TODO: Add support for taking logging level as input as well.
def get_logger(log_level=1, logger_name=_LOGGER_NAME, log_file=None):
if log_level == 0:
level = logging.ERROR
elif log_level == 1:
level = logging.INFO
else:
level = logging.DEBUG
config_dict = get_logger_config(log_level, logger_name, log_file)
logging.config.dictConfig(config_dict)
logger = logging.getLogger(logger_name)
logger.setLevel(level=level)
return logger
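# --- Illustrative usage sketch (not part of the original module) ---
# get_logger() configures handlers via dictConfig and returns a named logger;
# log_level maps 0 -> ERROR, 1 -> INFO, and anything higher -> DEBUG. The
# logger name and file below are placeholders.
if __name__ == "__main__":
    example_logger = get_logger(
        log_level=2, logger_name="example.worker", log_file="example_worker.log"
    )
    example_logger.debug("debug output goes to stdout and example_worker.log")
    example_logger.info("info output uses the standard formatter")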
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
from triton_distributed.runtime.parser import Parser
from triton_distributed.runtime.worker import Worker
def main(args=None):
args, cli_parser = Parser.parse_args(args)
    # TODO: Revisit the workflow args to simplify.
worker = Worker(
request_plane=NatsRequestPlane(args.request_plane_uri),
data_plane=UcpDataPlane(),
log_level=args.log_level,
operators=cli_parser.operator_configs,
metrics_port=args.metrics_port,
log_dir=args.log_dir,
name=args.name,
triton_log_path=args.triton_log_path,
)
worker.start()
if __name__ == "__main__":
sys.exit(main())
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for Operators"""
import abc
from dataclasses import dataclass, field
from typing import Any, Optional, Type
try:
from tritonserver import Server as TritonCore
TRITON_CORE_AVAILABLE = True
except ImportError:
TRITON_CORE_AVAILABLE = False
TritonCore = type(None) # type: ignore[misc,assignment]
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.runtime.remote_request import RemoteInferenceRequest
class Operator(abc.ABC):
@abc.abstractmethod
def __init__(
self,
name: str,
version: int,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
triton_core: Optional[TritonCore] = None,
):
pass
@abc.abstractmethod
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
pass
@dataclass
class OperatorConfig:
"""
Holds the properties of a hosted operator
"""
name: str
implementation: str | Type[Operator]
repository: Optional[str] = None
version: int = 1
max_inflight_requests: int = 5
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
)
log_level: Optional[int] = None
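# --- Illustrative sketch (not part of the original module) ---
# A minimal Operator implementation: it echoes every input tensor back as an
# output and marks the response as final. The response_sender()/send() calls
# follow the RemoteInferenceRequest/RemoteResponseSender API in
# remote_request.py; the class itself is hypothetical and only meant to show
# the shape of the interface.
class EchoOperator(Operator):
    def __init__(
        self,
        name,
        version,
        request_plane,
        data_plane,
        parameters=None,
        repository=None,
        logger=None,
        triton_core=None,
    ):
        self._name = name
        self._version = version
        self._logger = logger

    async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
        for request in requests:
            # Reuse the remote input tensors directly as outputs.
            await request.response_sender().send(
                outputs=dict(request.inputs), final=True
            )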
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from triton_distributed.runtime.worker import OperatorConfig
# Default values
DEFAULT_REQUEST_PLANE_URI = "nats://localhost:4222"
DEFAULT_LOG_LEVEL = 0
# Property keys
NAME = "name"
VERSION = "version"
MAX_INFLIGHT_REQUESTS = "max_inflight_requests"
PARAMETERS = "parameters"
MODULE = "module"
REPOSITORY = "repository"
IMPLEMENTATION = "implementation"
class InvalidArgumentError(Exception):
pass
def _parse_name_and_properties(args, valid_properties):
kind = "operator"
args_dict = {}
if len(args) == 1:
args_dict[NAME] = args[0]
else:
for arg in args:
values = arg.split(":")
if values[0] not in valid_properties:
                raise InvalidArgumentError(
                    f"Unexpected property found for `--{kind}`. Expected one of {valid_properties}, found {values[0]}"
                )
args_dict[values[0]] = ":".join(values[1:])
if values[0] == PARAMETERS:
parameter_file_path = args_dict[values[0]]
if not os.path.exists(parameter_file_path):
args_dict[values[0]] = json.loads(args_dict[values[0]])
else:
with open(parameter_file_path, "r") as f:
args_dict[values[0]] = json.load(f)
if NAME not in args_dict.keys():
raise InvalidArgumentError(
f"`name` is a required property for `--{kind}`. Missing `name:<{kind}_name>`, in {args}"
)
return args_dict
def _validate_operator_args(operator_args):
valid_properties = [
NAME,
VERSION,
MAX_INFLIGHT_REQUESTS,
MODULE,
PARAMETERS,
REPOSITORY,
]
properties = _parse_name_and_properties(operator_args, valid_properties)
for int_property in [VERSION, MAX_INFLIGHT_REQUESTS]:
if int_property in properties.keys():
try:
int(properties[int_property])
except ValueError:
                raise InvalidArgumentError(
                    f"Unexpected value provided for `{int_property}` for operator `{properties[NAME]}`. Expected an integer, got {properties[int_property]}"
                )
if MODULE not in properties.keys():
raise InvalidArgumentError(
f"{MODULE} property not provided for operator `{properties[NAME]}`. This is a required property."
)
properties[IMPLEMENTATION] = properties[MODULE]
properties.pop(MODULE)
return properties
class Parser:
@classmethod
def _validate_args(cls, args):
operator_configs: list[OperatorConfig] = []
for operator_args in args.operators:
operator_properties = _validate_operator_args(operator_args)
operator_configs.append(OperatorConfig(**operator_properties))
args.operator_configs = operator_configs
# TODO: Add validation for request plane URI
@classmethod
def parse_args(cls, args=None):
parser = argparse.ArgumentParser(description="Triton Worker Component")
parser.add_argument(
"-c",
"--request-plane-uri",
type=str,
default=DEFAULT_REQUEST_PLANE_URI,
help="Request plane URI for the worker",
)
parser.add_argument(
"-l",
"--log-level",
type=int,
default=DEFAULT_LOG_LEVEL,
            help="The logging level for the worker. Verbose logging can be enabled by specifying a value >= 1.",
)
parser.add_argument(
"--log-dir",
type=str,
default=None,
            help="Directory in which to write worker log files",
)
parser.add_argument(
"--triton-log-path",
type=str,
default=None,
            help="Path to the Triton log file",
)
parser.add_argument(
"--name",
type=str,
default=None,
            help="Name of the worker",
)
parser.add_argument(
"-op",
"--operator",
type=str,
action="append",
nargs="+",
default=[],
dest="operators",
            help="The operator to be hosted in the worker. Accepts properties in the format `name:<operator_name> module:<module_path>` with optional `version:<version>`, `max_inflight_requests:<count>`, `parameters:<json_or_file_path>`, and `repository:<path>`; `name` and `module` are required.",
)
parser.add_argument(
"--metrics-port",
type=int,
default=0,
            help="Port on which to expose Prometheus metrics for the worker",
        )
        # TODO: Add more options as per requirements.
args = parser.parse_args(args)
try:
Parser._validate_args(args)
except Exception as err:
parser.error(f"Failed to validate arguments {err=}, {type(err)=}")
return args, cls
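# --- Illustrative usage sketch (not part of the original module) ---
# Each `--operator` takes space separated `key:value` properties; `name` and
# `module` are required, and `module` is mapped to OperatorConfig.implementation.
# The module path below is a hypothetical placeholder.
if __name__ == "__main__":
    example_args, _ = Parser.parse_args(
        [
            "--request-plane-uri",
            "nats://localhost:4222",
            "--operator",
            "name:encoder",
            "module:my_pkg.encoder:EncoderOperator",
            "max_inflight_requests:8",
        ]
    )
    print(example_args.operator_configs)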
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for interacting with Triton Inference Server Models"""
import asyncio
import uuid
from typing import Optional
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.runtime.remote_request import RemoteInferenceRequest
from triton_distributed.runtime.remote_response import AsyncRemoteResponseIterator
class RemoteOperator:
def __init__(
self,
operator: str | tuple[str, int],
request_plane: RequestPlane,
data_plane: DataPlane,
component_id: Optional[uuid.UUID] = None,
):
if isinstance(operator, str):
self.name = operator
self.version = 1
else:
self.name = operator[0]
self.version = operator[1]
self._request_plane = request_plane
self._data_plane = data_plane
self.component_id = component_id
@property
def data_plane(self):
return self._data_plane
def create_request(self, **kwargs) -> RemoteInferenceRequest:
if "model_name" in kwargs:
kwargs.pop("model_name")
if "model_version" in kwargs:
kwargs.pop("model_version")
if "data_plane" in kwargs:
kwargs.pop("data_plane")
if "_request_plane" in kwargs:
kwargs.pop("_request_plane")
if "_model_infer_request" in kwargs:
kwargs.pop("_model_infer_request")
return RemoteInferenceRequest(
model_name=self.name,
model_version=self.version,
data_plane=self._data_plane,
_request_plane=None,
_model_infer_request=None,
**kwargs,
)
async def async_infer(
self,
inference_request: Optional[RemoteInferenceRequest] = None,
raise_on_error: bool = True,
**kwargs,
) -> AsyncRemoteResponseIterator:
if inference_request is None:
inference_request = RemoteInferenceRequest(
model_name=self.name,
model_version=self.version,
data_plane=self.data_plane,
_request_plane=None,
_model_infer_request=None,
**kwargs,
)
else:
inference_request.model_name = self.name
inference_request.model_version = self.version
if inference_request.data_plane != self.data_plane:
            raise ValueError(
                f"Data plane mismatch between remote request and remote operator: \n\n Operator: {self.data_plane} \n\n Request: {inference_request.data_plane}"
            )
if (inference_request.response_queue is not None) and (
not isinstance(inference_request.response_queue, asyncio.Queue)
):
raise ValueError("asyncio.Queue must be used for async response iterator")
response_iterator = AsyncRemoteResponseIterator(
self._data_plane,
inference_request,
inference_request.response_queue,
raise_on_error,
)
remote_inference_request = inference_request.to_model_infer_request()
await self._request_plane.post_request(
remote_inference_request,
response_handler=response_iterator._response_handler,
component_id=self.component_id,
)
return response_iterator
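# --- Illustrative usage sketch (not part of the original module) ---
# End-to-end client flow against a remote operator: connect the request and
# data planes, build a request with a numpy input, and iterate the responses.
# The operator name "encoder", the input name "text_input", and the request
# plane URL are placeholders.
async def _remote_operator_sketch():
    import numpy

    from triton_distributed.icp.nats_request_plane import NatsRequestPlane
    from triton_distributed.icp.ucp_data_plane import UcpDataPlane

    request_plane = NatsRequestPlane("nats://localhost:4222")
    await request_plane.connect()
    data_plane = UcpDataPlane()
    data_plane.connect()
    operator = RemoteOperator("encoder", request_plane, data_plane)
    request = operator.create_request(
        inputs={"text_input": numpy.array([[1.0]], dtype=numpy.float16)}
    )
    async for response in await operator.async_infer(inference_request=request):
        print(response.outputs, response.final)
    data_plane.close()
    await request_plane.close()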
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for sending inference requests to Triton Inference Server Models"""
from __future__ import annotations
import asyncio
import queue
import uuid
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Optional
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from triton_distributed.icp.request_plane import RequestPlane, get_icp_component_id
from triton_distributed.icp.tensor import Tensor
from triton_distributed.runtime.remote_response import RemoteInferenceResponse
from triton_distributed.runtime.remote_tensor import RemoteTensor
@dataclass
class RemoteInferenceRequest:
model_name: str
model_version: int
data_plane: DataPlane
component_id: Optional[uuid.UUID] = None
_request_plane: Optional[RequestPlane] = None
_model_infer_request: Optional[ModelInferRequest] = None
request_id: Optional[str] = None
correlation_id: Optional[int | str] = None
priority: Optional[int] = None
timeout: Optional[int] = None
inputs: dict[str, RemoteTensor | Any] = field(default_factory=dict)
store_inputs_in_request: set[str] = field(default_factory=set)
parameters: dict[str, str | int | bool | float] = field(default_factory=dict)
response_queue: Optional[queue.SimpleQueue | asyncio.Queue] = None
def _set_model_infer_request_inputs(
self,
remote_request: ModelInferRequest,
):
for name, value in self.inputs.items():
if not isinstance(value, RemoteTensor):
if not isinstance(value, Tensor):
tensor = Tensor._from_object(value)
else:
tensor = value
use_tensor_contents = name in self.store_inputs_in_request
remote_input = self.data_plane.put_input_tensor(
tensor, use_tensor_contents=use_tensor_contents
)
else:
remote_input = self.data_plane.create_input_tensor_reference(
value.remote_tensor
)
remote_input.name = name
remote_request.inputs.append(remote_input)
def _set_model_infer_request_parameters(self, remote_request: ModelInferRequest):
for key, value in self.parameters.items():
remote_value = remote_request.parameters[key]
if isinstance(value, str):
remote_value.string_param = value
elif isinstance(value, int):
remote_value.int64_param = value
elif isinstance(value, float):
remote_value.double_param = value
elif isinstance(value, bool):
remote_value.bool_param = value
else:
raise ValueError(f"Invalid parameter type: {type(value)}")
@staticmethod
def _set_parameters_from_model_infer_request(
result: RemoteInferenceRequest,
inference_request: ModelInferRequest,
):
for name, value in inference_request.parameters.items():
if value.HasField("bool_param"):
result.parameters[name] = value.bool_param
elif value.HasField("int64_param"):
result.parameters[name] = value.int64_param
elif value.HasField("double_param"):
result.parameters[name] = value.double_param
elif value.HasField("string_param"):
result.parameters[name] = value.string_param
@staticmethod
def _set_inputs_from_model_infer_request(
result: RemoteInferenceRequest,
inference_request: ModelInferRequest,
):
for remote_input in inference_request.inputs:
result.inputs[remote_input.name] = RemoteTensor(
remote_input, result.data_plane
)
def cancel(self):
raise NotImplementedError("Cancel not implemented")
def response_sender(self):
if self._request_plane is None or self._model_infer_request is None:
raise ValueError(
"Response only valid for requests received from request plane"
)
return RemoteResponseSender(
self._model_infer_request, self._request_plane, self.data_plane
)
@staticmethod
def from_model_infer_request(
request: ModelInferRequest, data_plane: DataPlane, request_plane: RequestPlane
) -> RemoteInferenceRequest:
result = RemoteInferenceRequest(
request.model_name,
int(request.model_version),
data_plane,
_request_plane=request_plane,
_model_infer_request=request,
)
if request.id is not None:
result.request_id = request.id
result.component_id = get_icp_component_id(request)
if "sequence_id" in request.parameters:
if request.parameters["sequence_id"].HasField("string_param"):
result.correlation_id = request.parameters["sequence_id"].string_param
else:
result.correlation_id = request.parameters["sequence_id"].int64_param
if "priority" in request.parameters:
result.priority = request.parameters["priority"].uint64_param
if "timeout" in request.parameters:
result.timeout = request.parameters["timeout"].uint64_param
RemoteInferenceRequest._set_inputs_from_model_infer_request(result, request)
RemoteInferenceRequest._set_parameters_from_model_infer_request(result, request)
return result
def to_model_infer_request(self) -> ModelInferRequest:
remote_request = ModelInferRequest()
remote_request.model_name = self.model_name
remote_request.model_version = str(self.model_version)
if self.request_id is not None:
remote_request.id = self.request_id
if self.priority is not None:
remote_request.parameters["priority"].uint64_param = self.priority
if self.timeout is not None:
remote_request.parameters["timeout"].uint64_param = self.timeout
if self.correlation_id is not None:
if isinstance(self.correlation_id, str):
remote_request.parameters[
"sequence_id"
].string_param = self.correlation_id
else:
remote_request.parameters[
"sequence_id"
].int64_param = self.correlation_id
self._set_model_infer_request_inputs(remote_request)
self._set_model_infer_request_parameters(remote_request)
return remote_request
class RemoteResponseSender:
response_counts: Counter = Counter()
def __init__(
self,
model_infer_request: ModelInferRequest,
request_plane: RequestPlane,
data_plane: DataPlane,
):
self._model_infer_request = model_infer_request
self._request_plane = request_plane
self._data_plane = data_plane
def create_response(self, **kwargs) -> RemoteInferenceResponse:
if "model_name" in kwargs:
kwargs.pop("model_name")
if "model_version" in kwargs:
kwargs.pop("model_version")
if "request_id" in kwargs:
kwargs.pop("request_id")
return RemoteInferenceResponse(
model_name=self._model_infer_request.model_name,
model_version=self._model_infer_request.model_version,
request_id=self._model_infer_request.id,
**kwargs,
)
async def send(
self, inference_response: Optional[RemoteInferenceResponse] = None, **kwargs
) -> None:
if inference_response is None:
inference_response = RemoteInferenceResponse(
model_name=self._model_infer_request.model_name,
model_version=self._model_infer_request.model_version,
request_id=self._model_infer_request.id,
**kwargs,
)
await self._request_plane.post_response(
self._model_infer_request,
inference_response.to_model_infer_response(self._data_plane),
)
if inference_response.final:
RemoteResponseSender.response_counts[
(
self._model_infer_request.model_name,
self._model_infer_request.model_version,
)
] += 1
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for receiving inference responses from Triton Distributed Operators"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferResponse
if TYPE_CHECKING:
from triton_distributed.runtime.remote_request import RemoteInferenceRequest
try:
from tritonserver import Tensor as TritonTensor
except ImportError:
TritonTensor = type(None) # type: ignore [misc, assignment]
import uuid
from triton_distributed.icp.request_plane import (
RequestPlaneError,
get_icp_component_id,
get_icp_final_response,
get_icp_response_error,
set_icp_final_response,
set_icp_response_error,
)
from triton_distributed.icp.tensor import Tensor
from triton_distributed.runtime.logger import get_logger
from triton_distributed.runtime.remote_tensor import RemoteTensor
logger = get_logger(logger_name=__name__)
class AsyncRemoteResponseIterator:
"""Asyncio compatible response iterator
Response iterators are returned from model inference methods and
allow users to process inference responses in the order they were
received for a request.
"""
def __init__(
self,
data_plane: DataPlane,
request: RemoteInferenceRequest,
user_queue: Optional[asyncio.Queue] = None,
raise_on_error: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Initialize AsyncRemoteResponseIterator
        AsyncRemoteResponseIterator objects are obtained from RemoteOperator
        inference methods and are not normally instantiated directly.
        Parameters
        ----------
        data_plane : DataPlane
            Data plane used to retrieve response tensors.
        request : RemoteInferenceRequest
            Remote inference request associated with the responses.
        user_queue : Optional[asyncio.Queue]
            Optional user queue for responses in addition to the internal
            iterator queue.
        raise_on_error : bool
            If True, response errors will be raised as exceptions.
        loop : Optional[asyncio.AbstractEventLoop]
            asyncio event loop used to enqueue responses.
"""
if loop is None:
loop = asyncio.get_running_loop()
self._loop = loop
self._queue: asyncio.Queue = asyncio.Queue()
self._user_queue = user_queue
self._complete = False
self._request = request
self._data_plane = data_plane
self._raise_on_error = raise_on_error
def __aiter__(self) -> AsyncRemoteResponseIterator:
"""Return async iterator. For use with async for loops.
Returns
-------
        AsyncRemoteResponseIterator
Examples
--------
        responses = await remote_operator.async_infer(inputs={"fp16_input":numpy.array([[1]],dtype=numpy.float16)})
async for response in responses:
print(numpy.from_dlpack(response.outputs["fp16_output"]))
"""
return self
async def __anext__(self):
"""Returns the next response received for a request
Returns the next response received for a request as an
awaitable object.
Raises
------
response.error
If raise_on_error is set to True, response errors are
raised as exceptions
StopAsyncIteration
Indicates all responses for a request have been received.
Final responses may or may not include outputs and must be
checked.
"""
if self._complete:
raise StopAsyncIteration
response = await self._queue.get()
self._complete = response.final
if response.error is not None and self._raise_on_error:
raise response.error
return response
    def cancel(self) -> None:
        """Cancel an inflight request
        Cancels an in-flight request by delegating to the underlying
        RemoteInferenceRequest. Cancellation is handled on a best effort
        basis and may not prevent execution of a request if it has
        already started or completed.
        Examples
        --------
        responses = await remote_operator.async_infer(inputs={"text_input":["hello"]})
        responses.cancel()
"""
if self._request is not None:
self._request.cancel()
def _response_handler(self, response: ModelInferResponse):
try:
if self._request is None:
raise ValueError("Response received after final response flag")
final = False
if response is None or get_icp_final_response(response):
final = True
remote_response = RemoteInferenceResponse.from_model_infer_response(
self._request, response, self._data_plane, final
)
asyncio.run_coroutine_threadsafe(
self._queue.put(remote_response), self._loop
)
if self._user_queue is not None:
asyncio.run_coroutine_threadsafe(
self._user_queue.put(remote_response), self._loop
)
if final:
del self._request
self._request = None
except Exception as e:
message = f"Catastrophic failure in response callback: {e}"
logger.exception(message)
# catastrophic failure
raise e from None
@dataclass
class RemoteInferenceResponse:
"""Dataclass representing an inference response.
Inference response objects are returned from response iterators
which are in turn returned from model inference methods. They
contain output data, output parameters, any potential errors
reported and a flag to indicate if the response is the final one
for a request.
Parameters
----------
    model_name : str
        Name of the model / operator associated with the response.
    model_version : int
        Version of the model / operator associated with the response.
    component_id : Optional[uuid.UUID], default None
        Identifier of the component that produced the response.
request_id : Optional[str], default None
Unique identifier for the inference request (if provided)
parameters : dict[str, str | int | bool], default {}
Additional parameters associated with the response.
    outputs : dict[str, RemoteTensor | Tensor], default {}
Output tensors for the inference.
error : Optional[RequestPlaneError], default None
Error (if any) that occurred in the processing of the request.
classification_label : Optional[str], default None
Classification label associated with the inference. Not currently supported.
final : bool, default False
Flag indicating if the response is final
"""
model_name: str
model_version: int
component_id: Optional[uuid.UUID] = None
request_id: Optional[str] = None
parameters: dict[str, str | int | bool] = field(default_factory=dict)
outputs: dict[str, RemoteTensor | Tensor] = field(default_factory=dict)
store_outputs_in_response: set[str] = field(default_factory=set)
error: Optional[RequestPlaneError] = None
classification_label: Optional[str] = None
final: bool = False
def _set_parameters_from_model_infer_response_parameters(
self, response: ModelInferResponse
):
for name, value in response.parameters.items():
if value.HasField("string_param"):
self.parameters[name] = value.string_param
elif value.HasField("int64_param"):
self.parameters[name] = value.int64_param
elif value.HasField("double_param"):
self.parameters[name] = value.double_param
elif value.HasField("bool_param"):
self.parameters[name] = value.bool_param
def _set_model_infer_response_outputs(
self, response: ModelInferResponse, data_plane: DataPlane
):
for name, value in self.outputs.items():
if not isinstance(value, RemoteTensor):
if not isinstance(value, Tensor) and not isinstance(
value, TritonTensor
):
tensor = Tensor._from_object(value)
else:
tensor = value
use_tensor_contents = name in self.store_outputs_in_response
remote_output = data_plane.put_output_tensor(
tensor, use_tensor_contents=use_tensor_contents
)
else:
remote_output = data_plane.create_output_tensor_reference(
value.remote_tensor
)
remote_output.name = name
response.outputs.append(remote_output)
def _set_model_infer_response_parameters(self, response: ModelInferResponse):
for key, value in self.parameters.items():
remote_value = response.parameters[key]
if isinstance(value, str):
remote_value.string_param = value
elif isinstance(value, int):
remote_value.int64_param = value
elif isinstance(value, float):
remote_value.double_param = value
elif isinstance(value, bool):
remote_value.bool_param = value
def to_model_infer_response(self, data_plane: DataPlane):
remote_response = ModelInferResponse()
remote_response.model_name = self.model_name
remote_response.model_version = str(self.model_version)
if self.request_id:
remote_response.id = self.request_id
if self.error:
set_icp_response_error(remote_response, self.error)
if self.final:
set_icp_final_response(remote_response, self.final)
self._set_model_infer_response_parameters(remote_response)
self._set_model_infer_response_outputs(remote_response, data_plane)
return remote_response
@staticmethod
def from_model_infer_response(
request: RemoteInferenceRequest,
response: ModelInferResponse,
data_plane: DataPlane,
final_response: bool,
) -> RemoteInferenceResponse:
result = RemoteInferenceResponse(
request.model_name,
request.model_version,
None,
request.request_id,
final=final_response,
)
if response is None:
return result
result.request_id = response.id
result.component_id = get_icp_component_id(response)
outputs = {}
for output in response.outputs:
outputs[output.name] = RemoteTensor(output, data_plane)
result.outputs = outputs
result._set_parameters_from_model_infer_response_parameters(response)
result.error = get_icp_response_error(response)
return result