"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "3983830e808167551bfa66a84e4476fe8f7212f6"
Commit fd42de29 authored by Neelay Shah, committed by GitHub

feat: adding initial icp

parent f753fffc
@@ -63,11 +63,14 @@ repos:
       - id: requirements-txt-fixer
       - id: trailing-whitespace
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.13.0
-    hooks:
-      - id: mypy
-        exclude: model.py # WAR errors about 'model.py' duplicate module name
+  # NOTE: removing from pre commit
+  # will move to gitlab ci to run in proper
+  # container
+  #- repo: https://github.com/pre-commit/mirrors-mypy
+  #  rev: v1.13.0
+  #  hooks:
+  #    - id: mypy
+  #      exclude: model.py # WAR errors about 'model.py' duplicate module name

   # Fast linting
   - repo: https://github.com/astral-sh/ruff-pre-commit
...
@@ -24,6 +24,8 @@ opentelemetry-sdk
 pre-commit
 protobuf==5.27.3
 pyright
+pytest-md-report
+pytest-mypy
 sentencepiece
 transformers
 tritonclient==2.53.0
...
@@ -16,7 +16,7 @@
 PROTO_SRC=$(dirname "$(realpath $0)")
 SOURCE_ROOT="$(realpath "${PROTO_SRC}/..")"
-PROTO_OUT=$SOURCE_ROOT/src/python/tdist/icp/protos
+PROTO_OUT=$SOURCE_ROOT/src/python/triton_distributed/icp/protos
 mkdir -p $PROTO_OUT
...
@@ -15,7 +15,7 @@
 syntax = "proto3";
-package tdist.icp;
+package triton.distributed.icp;
 //@@
 //@@.. cpp:var:: message ModelInferRequest
...
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract Class for interacting with Triton Inference Serving Platform Inter-Component Protocol Data Plane"""
import abc
import uuid
from typing import Optional, Sequence
import cupy
import numpy
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from tritonserver import (
DataType,
InvalidArgumentError,
MemoryBuffer,
MemoryType,
Tensor,
)
from tritonserver._api._datautils import (
STRING_TO_TRITON_MEMORY_TYPE,
TRITON_TO_NUMPY_DTYPE,
)
from tritonserver._c.triton_bindings import (
TRITONSERVER_DataTypeString as DataTypeString,
)
from tritonserver._c.triton_bindings import (
TRITONSERVER_MemoryTypeString as MemoryTypeString,
)
from tritonserver._c.triton_bindings import (
TRITONSERVER_StringToDataType as StringToDataType,
)
class DataPlaneError(Exception):
pass
ICP_TENSOR_URI = "icp_tensor_uri"
ICP_MEMORY_TYPE = "icp_memory_type"
ICP_MEMORY_TYPE_ID = "icp_memory_type_id"
ICP_TENSOR_SIZE = "icp_tensor_size"
def set_icp_shape(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
value: Sequence[int],
) -> None:
for dim in value:
message.shape.append(dim)
def get_icp_shape(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> Sequence[int]:
return message.shape
def set_icp_data_type(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
value: DataType,
) -> None:
message.datatype = DataTypeString(value)
def get_icp_data_type(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> DataType:
return StringToDataType(message.datatype)
def set_icp_tensor_uri(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
value: str,
) -> None:
message.parameters[ICP_TENSOR_URI].string_param = value
def get_icp_tensor_uri(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> str | None:
if ICP_TENSOR_URI not in message.parameters:
return None
return message.parameters[ICP_TENSOR_URI].string_param
def set_icp_tensor_size(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
value: int,
) -> None:
message.parameters[ICP_TENSOR_SIZE].uint64_param = value
def get_icp_tensor_size(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> int | None:
if ICP_TENSOR_SIZE not in message.parameters:
return None
return message.parameters[ICP_TENSOR_SIZE].uint64_param
def set_icp_memory_type(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
value: MemoryType,
) -> None:
message.parameters[ICP_MEMORY_TYPE].string_param = MemoryTypeString(value)
def get_icp_memory_type(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> MemoryType | None:
if ICP_MEMORY_TYPE not in message.parameters:
return None
return STRING_TO_TRITON_MEMORY_TYPE[
message.parameters[ICP_MEMORY_TYPE].string_param
]
def set_icp_memory_type_id(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
value: int,
) -> None:
message.parameters[ICP_MEMORY_TYPE_ID].int64_param = value
def get_icp_memory_type_id(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> int | None:
if ICP_MEMORY_TYPE_ID not in message.parameters:
return None
return message.parameters[ICP_MEMORY_TYPE_ID].int64_param
def set_icp_tensor_contents(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
tensor: Tensor,
) -> None:
set_icp_memory_type(message, MemoryType.CPU)
set_icp_memory_type_id(message, 0)
set_icp_tensor_size(message, tensor.size)
if tensor.data_type == DataType.BYTES:
array = tensor.to_bytes_array()
for i in list(array.flat):
message.contents.bytes_contents.append(i)
else:
if tensor.memory_type == MemoryType.CPU:
# Directly use the memory buffer when the contents are on the CPU.
array = tensor.memory_buffer.owner
elif tensor.memory_type == MemoryType.GPU:
with cupy.cuda.Device(tensor.memory_buffer.memory_type_id):
array = cupy.from_dlpack(tensor)
else:
raise InvalidArgumentError(
f"Invalid Tensor Memory Type {tensor.memory_type}"
)
message.contents.bytes_contents.append(array.tobytes())
def get_icp_tensor_contents(
message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> Tensor | None:
if not message.HasField("contents"):
# Return None if the contents are not part of the message
return None
datatype = get_icp_data_type(message)
shape = get_icp_shape(message)
tensor = None
if datatype == DataType.BYTES:
array = numpy.array(
[
message.contents.bytes_contents[i]
for i in range(len(message.contents.bytes_contents))
]
)
array = numpy.reshape(array, shape)
tensor = Tensor.from_bytes_array(array)
else:
array = numpy.array(
numpy.frombuffer(
message.contents.bytes_contents[0],
dtype=TRITON_TO_NUMPY_DTYPE[datatype],
)
)
tensor = Tensor(datatype, shape, MemoryBuffer.from_dlpack(array))
return tensor
class DataPlane(abc.ABC):
def __init__(self) -> None:
pass
@abc.abstractmethod
def connect(self) -> None:
pass
@abc.abstractmethod
def put_input_tensor(
self, tensor: Tensor, tensor_id: Optional[uuid.UUID], use_tensor_contents: bool
) -> ModelInferRequest.InferInputTensor:
pass
@abc.abstractmethod
def put_output_tensor(
self, tensor: Tensor, tensor_id: Optional[uuid.UUID], use_tensor_contents: bool
) -> ModelInferResponse.InferOutputTensor:
pass
@abc.abstractmethod
def get_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
requested_memory_type: Optional[MemoryType] = None,
requested_memory_type_id: Optional[int] = None,
) -> Tensor:
pass
@abc.abstractmethod
def create_input_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> ModelInferRequest.InferInputTensor:
pass
@abc.abstractmethod
def create_output_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> ModelInferResponse.InferOutputTensor:
pass
@abc.abstractmethod
def release_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> None:
pass
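A minimal usage sketch of the contents-based path (illustrative only, not part of this commit): the helper functions above embed a CPU tensor's bytes directly in the protobuf message and reconstruct a Tensor from them. Import paths assume the package layout introduced by this commit.

import numpy
from triton_distributed.icp.data_plane import (
    get_icp_tensor_contents,
    set_icp_data_type,
    set_icp_shape,
    set_icp_tensor_contents,
)
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from tritonserver import Tensor

# Wrap a host numpy array as a Triton tensor.
original = Tensor.from_dlpack(numpy.arange(6, dtype=numpy.float32).reshape(2, 3))

# Describe the tensor on an InferInputTensor and embed its bytes in the message.
message = ModelInferRequest.InferInputTensor()
set_icp_data_type(message, original.data_type)
set_icp_shape(message, original.shape)
set_icp_tensor_contents(message, original)

# On the receiving side, rebuild a Tensor from the embedded contents.
restored = get_icp_tensor_contents(message)
assert restored is not None and list(restored.shape) == [2, 3]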
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import os
import shutil
import subprocess
import uuid
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Dict, Optional
from urllib.parse import urlsplit, urlunsplit
import nats
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.request_plane import (
RequestPlane,
get_icp_final_response,
get_icp_request_id,
get_icp_response_error,
get_icp_response_to_uri,
set_icp_component_id,
set_icp_request_id,
set_icp_request_to_uri,
set_icp_response_to_uri,
)
from tritonserver import InvalidArgumentError
class AsyncModelInferRequestIterator:
def __init__(self, requests: list[ModelInferRequest]) -> None:
self._requests = requests
def __aiter__(self) -> AsyncModelInferRequestIterator:
return self
async def __anext__(self):
if not self._requests:
raise StopAsyncIteration
return self._requests.pop(0)
class AsyncModelInferResponseIterator:
def __init__(
self,
queue: Optional[asyncio.Queue],
raise_on_error=False,
) -> None:
self._queue = queue
self._complete = False
self._raise_on_error = raise_on_error
if not self._queue:
self._complete = True
def __aiter__(self) -> AsyncModelInferResponseIterator:
return self
async def __anext__(self):
if self._complete or self._queue is None:
raise StopAsyncIteration
response = await self._queue.get()
self._complete = get_icp_final_response(response)
error = get_icp_response_error(response)
if error is not None and self._raise_on_error:
raise error
return response
def cancel(self) -> None:
raise NotImplementedError()
class NatsServer:
def __init__(
self,
port: int = 4223,
store_dir: str = "/tmp/nats_store",
log_dir: str = "logs",
debug: bool = False,
clear_store: bool = True,
dry_run: bool = False,
) -> None:
self._process = None
self.port = port
self.url = f"nats://localhost:{port}"
command = [
"/usr/local/bin/nats-server",
"--jetstream",
"--port",
str(port),
"--store_dir",
store_dir,
]
if debug:
command.extend(["--debug", "--trace"])
if dry_run:
print(command)
return
os.makedirs(log_dir, exist_ok=True)
if clear_store:
shutil.rmtree(store_dir, ignore_errors=True)
with open(f"{log_dir}/nats_server.stdout.log", "wt") as output_:
with open(f"{log_dir}/nats_server.stderr.log", "wt") as output_err:
process = subprocess.Popen(
command,
stdin=subprocess.DEVNULL,
stdout=output_,
stderr=output_err,
)
self._process = process
def __del__(self):
if self._process:
self._process.terminate()
self._process.kill()
self._process.wait()
class NatsRequestPlane(RequestPlane):
@property
def component_id(self):
return self._component_id
@property
def response_uri(self):
return self._response_uri
async def close(self):
if self._nats_client:
await self._nats_client.close()
def __del__(self):
if self._event_loop and not self._event_loop.is_closed():
self._event_loop.run_until_complete(self.close())
def __init__(
self,
request_plane_uri: str = "nats://localhost:4222",
component_id: Optional[uuid.UUID] = None,
) -> None:
self._request_plane_uri = request_plane_uri
self._component_id = component_id if component_id else uuid.uuid1()
self._response_stream_name = f"component-{self._component_id}-response"
split_uri = urlsplit(self._request_plane_uri)._asdict()
split_uri["path"] = self._response_stream_name
self._response_uri = str(urlunsplit(split_uri.values()))
self._model_streams: Dict[
tuple[str, str], # model_name, model_version
tuple[
str, # stream_name
Optional[nats.js.JetStreamContext.PullSubscription], # general requests
Optional[nats.js.JetStreamContext.PullSubscription], # direct requests
],
] = {}
self._posted_requests: Dict[
uuid.UUID, # request id
tuple[
Optional[asyncio.Queue], # response queue
Optional[Callable[[ModelInferResponse], None | Awaitable[None]]],
Optional[Callable[[ModelInferResponse], Awaitable[None]]],
],
] = {}
self._jet_stream: Optional[nats.js.JetStreamContext] = None
self._event_loop: Optional[asyncio.AbstractEventLoop] = None
def _replace_special_chars(self, stream_name):
return stream_name.replace(".", "-")
async def _get_model_stream(
self, model_name: str, model_version: str, subscribe: bool
) -> tuple[
str,
Optional[nats.js.JetStreamContext.PullSubscription],
Optional[nats.js.JetStreamContext.PullSubscription],
]:
if self._jet_stream is None:
raise InvalidArgumentError("Not Connected!")
if (model_name, model_version) in self._model_streams:
return self._model_streams[(model_name, model_version)]
model_stream_name = self._replace_special_chars(
f"model-{model_name}-{model_version}"
)
await self._jet_stream.add_stream(
name=model_stream_name,
subjects=[model_stream_name, model_stream_name + ".*"],
retention=nats.js.api.RetentionPolicy.WORK_QUEUE,
)
general_requests = None
directed_requests = None
if subscribe:
general_requests = await self._jet_stream.pull_subscribe(
subject=model_stream_name,
stream=model_stream_name,
durable=model_stream_name,
)
directed_subject = f"{model_stream_name}.{self._component_id}"
directed_durable = f"{model_stream_name}-{self._component_id}"
directed_requests = await self._jet_stream.pull_subscribe(
subject=directed_subject,
stream=model_stream_name,
durable=directed_durable,
)
return self._model_streams.setdefault(
(model_name, model_version),
(model_stream_name, general_requests, directed_requests),
)
async def _response_callback(self, message):
await message.ack()
response = ModelInferResponse()
response.ParseFromString(message.data)
request_id = get_icp_request_id(response)
if request_id in self._posted_requests:
response_queue, handler, async_handler = self._posted_requests[request_id]
if get_icp_final_response(response):
del self._posted_requests[request_id]
if response_queue:
return await response_queue.put(response)
if async_handler is not None:
return await async_handler(response)
if handler is not None:
return handler(response)
async def connect(self):
self._nats_client = await nats.connect(self._request_plane_uri)
self._jet_stream = self._nats_client.jetstream()
self._event_loop = asyncio.get_event_loop()
await self._jet_stream.add_stream(
name=self._response_stream_name,
subjects=[self._response_stream_name],
retention=nats.js.api.RetentionPolicy.WORK_QUEUE,
)
await self._jet_stream.subscribe(
self._response_stream_name,
cb=self._response_callback,
durable=self._response_stream_name,
stream=self._response_stream_name,
)
async def pull_requests(
self,
model_name: str,
model_version: str,
number_requests: int = 1,
timeout: Optional[float] = None,
) -> AsyncIterator[ModelInferRequest]:
# Note: directed requests and general requests are
# pulled in parallel. Directed requests are consumed
# first. If there are more requests than the batch size,
# the extra requests are scheduled for redelivery via nak.
requests: list[ModelInferRequest] = []
acks = []
_, general, directed = await self._get_model_stream(
model_name, model_version, subscribe=True
)
tasks = [
asyncio.create_task(
subscription.fetch(batch=number_requests, timeout=timeout)
)
for subscription in [directed, general]
if subscription
]
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
for task in tasks:
if task not in done:
continue
try:
for message in task.result():
if len(requests) < number_requests:
request = ModelInferRequest()
request.ParseFromString(message.data)
requests.append(request)
acks.append(message.ack())
else:
acks.append(message.nak())
except nats.errors.TimeoutError:
continue
asyncio.gather(*acks)
return AsyncModelInferRequestIterator(requests)
@staticmethod
async def _single_response(response: ModelInferResponse):
yield response
async def post_response(
self,
request: ModelInferRequest,
responses: AsyncIterator[ModelInferResponse] | ModelInferResponse,
):
if self._jet_stream is None:
raise InvalidArgumentError("Not Connected!")
request_id = get_icp_request_id(request)
if request_id is None:
raise InvalidArgumentError("ICP request must have request id")
response_to_uri = get_icp_response_to_uri(request)
if not response_to_uri:
raise InvalidArgumentError(
"Attempting to send a response when non requested"
)
parsed = urlsplit(response_to_uri)
response_stream = parsed.path.replace("/", "")
if isinstance(responses, ModelInferResponse):
responses = NatsRequestPlane._single_response(responses)
async for response in responses:
set_icp_request_id(response, request_id)
response.model_name = request.model_name
response.model_version = request.model_version
response.id = request.id
set_icp_component_id(response, self._component_id)
await self._jet_stream.publish(
response_stream,
response.SerializeToString(),
stream=response_stream,
)
async def post_request(
self,
request: ModelInferRequest,
*,
component_id: Optional[uuid.UUID] = None,
response_iterator: bool = False,
response_handler: Optional[
Callable[[ModelInferResponse], None | Awaitable[None]]
] = None,
) -> AsyncIterator[ModelInferResponse]:
if self._jet_stream is None:
raise InvalidArgumentError("Not Connected!")
if response_iterator and response_handler:
raise InvalidArgumentError(
"Can only specify either response handler or response iterator"
)
async_response_handler = None
response_queue = None
if response_handler or response_iterator:
request_id = get_icp_request_id(request)
if request_id is None:
request_id = uuid.uuid1()
set_icp_request_id(request, request_id)
set_icp_response_to_uri(request, self._response_uri)
set_icp_component_id(request, self._component_id)
async_response_handler = (
response_handler
if asyncio.iscoroutinefunction(response_handler)
else None
)
response_queue = None
if response_iterator:
response_queue = asyncio.Queue()
self._posted_requests[request_id] = (
response_queue,
response_handler,
async_response_handler,
)
stream_name, _, _ = await self._get_model_stream(
request.model_name, request.model_version, subscribe=False
)
subject = stream_name
if component_id:
subject += f".{component_id}"
split_uri = urlsplit(self._request_plane_uri)._asdict()
split_uri["path"] = subject
set_icp_request_to_uri(request, str(urlunsplit(split_uri.values())))
await self._jet_stream.publish(
subject,
request.SerializeToString(),
stream=stream_name,
)
return AsyncModelInferResponseIterator(response_queue)
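A minimal worker-side sketch (illustrative only, not part of this commit): pull a small batch of requests for a placeholder model name and answer each with a single final response. It assumes a JetStream-enabled nats-server is reachable at the URI shown below.

import asyncio

from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferResponse
from triton_distributed.icp.request_plane import set_icp_final_response


async def serve_once() -> None:
    plane = NatsRequestPlane("nats://localhost:4222")  # assumed server location
    await plane.connect()
    # Directed and general requests are fetched together; extras beyond the
    # batch size are nak'd by pull_requests for redelivery.
    requests = await plane.pull_requests("example_model", "1", number_requests=4, timeout=1.0)
    async for request in requests:
        response = ModelInferResponse()
        set_icp_final_response(response, True)
        await plane.post_response(request, response)
    await plane.close()


asyncio.run(serve_once())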
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract Class for interacting with the Triton Inference Serving Platform Inter-Component Protocol Control Plane"""
import abc
import uuid
from typing import AsyncIterator, Awaitable, Callable, Optional
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from tritonserver import TritonError
ICP_REQUEST_ID = "icp_request_id"
ICP_FINAL_RESPONSE = "icp_final_response"
ICP_RESPONSE_FROM_URI = "icp_response_from_uri"
ICP_COMPONENT_ID = "icp_component_id"
ICP_RESPONSE_TO_URI = "icp_response_to_uri"
ICP_REQUEST_TO_URI = "icp_request_to_uri"
ICP_REQUEST_CANCELLED = "icp_request_cancelled"
ICP_ERROR = "icp_response_error"
def get_icp_request_id(
message: ModelInferRequest | ModelInferResponse,
) -> uuid.UUID | None:
if ICP_REQUEST_ID not in message.parameters:
return None
return uuid.UUID(message.parameters[ICP_REQUEST_ID].string_param)
def set_icp_request_id(
message: ModelInferRequest | ModelInferResponse, value: uuid.UUID
) -> None:
message.parameters[ICP_REQUEST_ID].string_param = str(value)
def get_icp_response_error(message: ModelInferResponse) -> TritonError | None:
if ICP_ERROR not in message.parameters:
return None
return TritonError(message.parameters[ICP_ERROR].string_param)
def set_icp_response_error(message: ModelInferResponse, value: TritonError) -> None:
message.parameters[ICP_ERROR].string_param = str(value)
def get_icp_final_response(
message: ModelInferResponse,
) -> bool:
if ICP_FINAL_RESPONSE not in message.parameters:
return False
return message.parameters[ICP_FINAL_RESPONSE].bool_param
def set_icp_final_response(message: ModelInferResponse, value: bool) -> None:
message.parameters[ICP_FINAL_RESPONSE].bool_param = value
def get_icp_response_to_uri(message: ModelInferRequest) -> str | None:
if ICP_RESPONSE_TO_URI not in message.parameters:
return None
return message.parameters[ICP_RESPONSE_TO_URI].string_param
def get_icp_component_id(
message: ModelInferRequest | ModelInferResponse,
) -> uuid.UUID | None:
if ICP_COMPONENT_ID not in message.parameters:
return None
return uuid.UUID(message.parameters[ICP_COMPONENT_ID].string_param)
def set_icp_component_id(
message: ModelInferRequest | ModelInferResponse, value: uuid.UUID
) -> None:
message.parameters[ICP_COMPONENT_ID].string_param = str(value)
def set_icp_response_to_uri(message: ModelInferRequest, value: str) -> None:
message.parameters[ICP_RESPONSE_TO_URI].string_param = value
def get_icp_request_to_uri(message: ModelInferRequest) -> str | None:
if ICP_REQUEST_TO_URI not in message.parameters:
return None
return message.parameters[ICP_REQUEST_TO_URI].string_param
def set_icp_request_to_uri(message: ModelInferRequest, value: str) -> None:
message.parameters[ICP_REQUEST_TO_URI].string_param = value
def get_icp_response_from_uri(message: ModelInferResponse) -> str | None:
if ICP_RESPONSE_FROM_URI not in message.parameters:
return None
return message.parameters[ICP_RESPONSE_FROM_URI].string_param
def set_icp_response_from_uri(message: ModelInferResponse, value: str) -> None:
message.parameters[ICP_RESPONSE_FROM_URI].string_param = value
class RequestPlane(abc.ABC):
@property
@abc.abstractmethod
def component_id(self) -> uuid.UUID:
pass
@abc.abstractmethod
async def connect(self) -> None:
pass
@abc.abstractmethod
async def pull_requests(
self,
model_name: str,
model_version: str,
number_requests: int = 1,
timeout: Optional[float] = None,
) -> AsyncIterator[ModelInferRequest]:
pass
@abc.abstractmethod
async def post_response(
self,
request: ModelInferRequest,
responses: AsyncIterator[ModelInferResponse] | ModelInferResponse,
) -> None:
pass
@abc.abstractmethod
async def post_request(
self,
request: ModelInferRequest,
*,
component_id: Optional[uuid.UUID] = None,
response_iterator: bool = True,
response_handler: Optional[
Callable[[ModelInferResponse], None | Awaitable[None]]
] = None,
) -> AsyncIterator[ModelInferResponse]:
pass
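A short sketch of how the parameter helpers above tag messages (illustrative only, not part of this commit): the request id and final-response flag travel as protobuf parameters and are read back with the matching getters. The model name is a placeholder.

import uuid

from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.request_plane import (
    get_icp_final_response,
    get_icp_request_id,
    set_icp_final_response,
    set_icp_request_id,
)

request = ModelInferRequest(model_name="example_model", model_version="1")
request_id = uuid.uuid1()
set_icp_request_id(request, request_id)  # stored as a string parameter on the message

response = ModelInferResponse()
set_icp_request_id(response, request_id)  # correlates the response with its request
set_icp_final_response(response, True)    # marks the end of the response stream

assert get_icp_request_id(response) == request_id
assert get_icp_final_response(response)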
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import logging
import threading
import uuid
from enum import IntEnum, auto
from functools import cached_property
from typing import Dict, Optional, Tuple
from urllib.parse import urlsplit
import cupy
import numpy
import tritonserver
import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.data_plane import (
DataPlane,
DataPlaneError,
get_icp_data_type,
get_icp_memory_type,
get_icp_memory_type_id,
get_icp_shape,
get_icp_tensor_contents,
get_icp_tensor_size,
get_icp_tensor_uri,
set_icp_data_type,
set_icp_memory_type,
set_icp_memory_type_id,
set_icp_shape,
set_icp_tensor_contents,
set_icp_tensor_size,
set_icp_tensor_uri,
)
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from tritonserver import InvalidArgumentError, MemoryBuffer, MemoryType, Tensor
LOGGER = logging.getLogger(__name__)
class UCP_DATA_PLANE_COMMANDS(IntEnum):
GET = auto()
CREATE_REFERENCE = auto()
RELEASE = auto()
# UCP can deadlock when multiple instances are created in a single process,
# so create a singleton.
_ucp_data_plane_singleton = None
def UcpDataPlane(
hostname: Optional[str] = None, port: int = 0, keep_endpoints_open: bool = False
):
global _ucp_data_plane_singleton
if _ucp_data_plane_singleton is None:
_ucp_data_plane_singleton = _UcpDataPlane(hostname, port, keep_endpoints_open)
return _ucp_data_plane_singleton
class _UcpDataPlane(DataPlane):
def __init__(
self,
hostname: Optional[str] = None,
port: int = 0,
keep_endpoints_open: bool = False,
) -> None:
self._tensor_store: Dict[uuid.UUID, Tensor] = {}
self._id_size = len(uuid.uuid1().bytes)
self._port = port
self._hostname = hostname or ucp.get_address()
self._event_loop_thread = threading.Thread(
target=self._run_event_loop, daemon=True
)
self._start_event = threading.Event()
self._listener: Optional[ucp.Listener] = None
self._event_loop: Optional[asyncio.AbstractEventLoop] = None
self._closed = False
self._keep_endpoints_open = keep_endpoints_open
self._endpoints: Dict[Tuple[str, int], ucp.Endpoint] = {}
LOGGER.debug(
"Creating UCP data plane with keep_endpoints_open=%s", keep_endpoints_open
)
@cached_property
def _cuda_is_available(self):
# Note: cupy.cuda.is_available() initializes CUDA
# and can't be called when forking subprocesses;
# care should be taken to only call it within
# subprocesses or to use the 'spawn' start method.
try:
return cupy.cuda.is_available()
except CUDARuntimeError:
return False
@property
def hostname(self) -> str:
return self._hostname
@property
def port(self) -> int:
return self._port
def connect(self) -> None:
if self._event_loop is None:
self._event_loop_thread.start()
self._start_event.wait()
if self._listener is None or self._listener.closed():
raise DataPlaneError("Unable to start data plane")
async def _close(self, wait_for_release=0):
self._closed = True
if self._listener is not None:
if wait_for_release:
while self._tensor_store and wait_for_release:
await asyncio.sleep(1)
wait_for_release -= 1
self._listener.close()
self._listener = None
def close(self, wait_for_release=0):
if self._event_loop is None:
return
if self._event_loop.is_closed():
return
asyncio.run_coroutine_threadsafe(
self._close(wait_for_release),
self._event_loop,
).result()
def __del__(self):
self.close()
def _run_event_loop(self):
asyncio.run(self._serve())
async def _serve(self):
self._event_loop = asyncio.get_running_loop()
try:
self._listener = ucp.create_listener(self._send_receive, self._port)
self._port = self._listener.port
self._start_event.set()
except Exception:
self._listener = None
self._start_event.set()
while self._listener is not None and not self._listener.closed():
await asyncio.sleep(1)
async def _send_receive(self, ep):
while True:
tensor_id_bytes = numpy.empty(self._id_size, dtype="u1")
await ep.recv(tensor_id_bytes)
tensor_id = uuid.UUID(bytes=tensor_id_bytes.tobytes())
command = numpy.empty(1, dtype="u1")
await ep.recv(command)
if command == UCP_DATA_PLANE_COMMANDS.GET:
if tensor_id in self._tensor_store:
tensor = self._tensor_store[tensor_id]
array_module = numpy
if tensor.memory_type == tritonserver.MemoryType.CPU:
array_module = numpy
device_manager = contextlib.nullcontext()
elif tensor.memory_type == tritonserver.MemoryType.GPU:
array_module = cupy
device_manager = cupy.cuda.Device(
tensor.memory_buffer.memory_type_id
)
else:
raise InvalidArgumentError(
f"Invalid Memory Type {tensor.memory_type}"
)
with device_manager:
if tensor.data_type == tritonserver.DataType.BYTES:
array = tensor.memory_buffer.owner
else:
array = array_module.from_dlpack(tensor)
await ep.send(array)
elif command == UCP_DATA_PLANE_COMMANDS.CREATE_REFERENCE:
if tensor_id in self._tensor_store:
reference_tensor_id = uuid.uuid1()
self._tensor_store[reference_tensor_id] = self._tensor_store[
tensor_id
]
await ep.send(numpy.array(reference_tensor_id.bytes))
elif command == UCP_DATA_PLANE_COMMANDS.RELEASE:
if tensor_id in self._tensor_store:
del self._tensor_store[tensor_id]
await ep.send(numpy.array(tensor_id.bytes))
if not self._keep_endpoints_open:
break
await ep.close()
def _put_tensor(
self,
result: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
tensor: Tensor,
tensor_id: Optional[uuid.UUID] = None,
use_tensor_contents: bool = False,
):
if self._closed:
raise DataPlaneError("Adding tensor after close")
set_icp_data_type(result, tensor.data_type)
set_icp_shape(result, tensor.shape)
if use_tensor_contents:
set_icp_tensor_contents(result, tensor)
else:
if tensor_id is None:
tensor_id = uuid.uuid1()
self._tensor_store[tensor_id] = tensor
tensor_uri = f"ucp://{self._hostname}:{self._port}/{tensor_id}"
set_icp_tensor_uri(result, tensor_uri)
set_icp_memory_type(result, tensor.memory_buffer.memory_type)
set_icp_memory_type_id(result, tensor.memory_buffer.memory_type_id)
set_icp_tensor_size(result, tensor.size)
def put_input_tensor(
self,
tensor: Tensor,
tensor_id: Optional[uuid.UUID] = None,
use_tensor_contents: bool = False,
) -> ModelInferRequest.InferInputTensor:
"""Put an input tensor into the data plane or within
returned ModelInferRequest.InferInputTensor itself.
Args:
tensor: The tensor to put.
tensor_id: The id of the tensor to put.
If not provided, a new id will be generated.
use_tensor_contents: when True, tensor data will be
added directly to ModelInferRequest.InferInputTensor
contents field; otherwise tensor data will be sent
separately on the data plane.
Returns:
ModelInferRequest.InferInputTensor object.
"""
result = ModelInferRequest.InferInputTensor()
self._put_tensor(
result,
tensor,
tensor_id=tensor_id,
use_tensor_contents=use_tensor_contents,
)
return result
def put_output_tensor(
self,
tensor: Tensor,
tensor_id: Optional[uuid.UUID] = None,
use_tensor_contents: bool = False,
) -> ModelInferResponse.InferOutputTensor:
"""Put an output tensor into the data plane or within
returned ModelInferResponse.InferOutputTensor itself.
Args:
tensor: The tensor to put.
tensor_id: The id of the tensor to put.
If not provided, a new id will be generated.
use_tensor_contents: when True, tensor data will be
added directly to ModelInferResponse.InferOutputTensor
contents field; otherwise tensor data will be sent
separately on the data plane.
Returns:
ModelInferResponse.InferOutputTensor object.
"""
result = ModelInferResponse.InferOutputTensor()
self._put_tensor(
result,
tensor,
tensor_id=tensor_id,
use_tensor_contents=use_tensor_contents,
)
return result
def _split_tensor_uri(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> tuple[uuid.UUID, str, int]:
tensor_uri = get_icp_tensor_uri(remote_tensor)
split_uri = urlsplit(tensor_uri)
path = str(split_uri.path).replace("/", "")
tensor_id = uuid.UUID(path)
host = split_uri.hostname
port = split_uri.port
if host is None or not isinstance(host, str):
raise DataPlaneError(f"Invalid host {host}")
if port is None:
raise DataPlaneError(f"Invalid Port {port}")
return tensor_id, host, port
async def _get_remote_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
requested_memory_type: Optional[MemoryType],
requested_memory_type_id: Optional[int],
) -> Tensor:
tensor_contents = get_icp_tensor_contents(remote_tensor)
if tensor_contents is not None:
return tensor_contents
tensor_size = get_icp_tensor_size(remote_tensor)
memory_type = get_icp_memory_type(remote_tensor)
data_type = get_icp_data_type(remote_tensor)
shape = get_icp_shape(remote_tensor)
tensor_id, host, port = self._split_tensor_uri(remote_tensor)
storage = None
if tensor_size is None:
raise DataPlaneError("tensor size can not be none")
if requested_memory_type is not None:
memory_type = requested_memory_type
if memory_type == tritonserver.MemoryType.GPU and self._cuda_is_available:
array_module = cupy
if requested_memory_type_id is not None:
device_manager = cupy.cuda.Device(requested_memory_type_id)
else:
device_manager = contextlib.nullcontext()
else:
array_module = numpy
device_manager = contextlib.nullcontext()
with device_manager:
storage = array_module.empty(tensor_size, dtype="u1")
try:
endpoint = await self._create_endpoint(host, port)
await endpoint.send(numpy.array(tensor_id.bytes))
await endpoint.send(
numpy.array(UCP_DATA_PLANE_COMMANDS.GET, dtype="u1")
)
await endpoint.recv(storage)
if not self._keep_endpoints_open:
await self._close_endpoint(host, port)
return Tensor(data_type, shape, MemoryBuffer.from_dlpack(storage))
except Exception as e:
raise DataPlaneError(f"Error Getting Tensor:\n{remote_tensor}") from e
async def _create_remote_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
result: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
):
tensor_size = get_icp_tensor_size(remote_tensor)
memory_type = get_icp_memory_type(remote_tensor)
memory_type_id = get_icp_memory_type_id(remote_tensor)
if tensor_size is None or memory_type is None or memory_type_id is None:
raise DataPlaneError("tensor size and memory type must not be none")
set_icp_shape(result, get_icp_shape(remote_tensor))
set_icp_data_type(result, get_icp_data_type(remote_tensor))
set_icp_tensor_size(result, tensor_size)
set_icp_memory_type(result, memory_type)
set_icp_memory_type_id(result, memory_type_id)
if remote_tensor.HasField("contents"):
for value in remote_tensor.contents.bytes_contents:
result.contents.bytes_contents.append(value)
return
tensor_id, host, port = self._split_tensor_uri(remote_tensor)
try:
endpoint = await self._create_endpoint(host, port)
await endpoint.send(numpy.array(tensor_id.bytes))
await endpoint.send(
numpy.array(UCP_DATA_PLANE_COMMANDS.CREATE_REFERENCE, dtype="u1")
)
reference_tensor_id_bytes = numpy.empty(self._id_size, dtype="u1")
await endpoint.recv(reference_tensor_id_bytes)
if not self._keep_endpoints_open:
await self._close_endpoint(host, port)
reference_tensor_id = uuid.UUID(bytes=reference_tensor_id_bytes.tobytes())
set_icp_tensor_uri(result, f"ucp://{host}:{port}/{reference_tensor_id}")
except Exception as e:
raise DataPlaneError("Error Referencing Tensor:\n{remote_tensor}") from e
async def _release_remote_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
):
tensor_id, host, port = self._split_tensor_uri(remote_tensor)
try:
endpoint = await self._create_endpoint(host, port)
await endpoint.send(numpy.array(tensor_id.bytes))
await endpoint.send(
numpy.array(UCP_DATA_PLANE_COMMANDS.RELEASE, dtype="u1")
)
ack_tensor_id = numpy.empty(self._id_size, dtype="u1")
await endpoint.recv(ack_tensor_id)
if not self._keep_endpoints_open:
await self._close_endpoint(host, port)
except Exception as e:
raise DataPlaneError(f"Error Releasing Tensor:\n{remote_tensor}") from e
def get_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
requested_memory_type: Optional[MemoryType] = None,
requested_memory_type_id: Optional[int] = None,
) -> Tensor:
if self._event_loop is None:
raise DataPlaneError("Not connected")
return asyncio.run_coroutine_threadsafe(
self._get_remote_tensor(
remote_tensor, requested_memory_type, requested_memory_type_id
),
self._event_loop,
).result()
def create_input_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> ModelInferRequest.InferInputTensor:
if self._event_loop is None:
raise DataPlaneError("Not connected")
result = ModelInferRequest.InferInputTensor()
asyncio.run_coroutine_threadsafe(
self._create_remote_tensor_reference(remote_tensor, result),
self._event_loop,
).result()
return result
def create_output_tensor_reference(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> ModelInferResponse.InferOutputTensor:
if self._event_loop is None:
raise DataPlaneError("Not connected")
result = ModelInferResponse.InferOutputTensor()
asyncio.run_coroutine_threadsafe(
self._create_remote_tensor_reference(remote_tensor, result),
self._event_loop,
).result()
return result
def release_tensor(
self,
remote_tensor: ModelInferRequest.InferInputTensor
| ModelInferResponse.InferOutputTensor,
) -> None:
if remote_tensor.HasField("contents"):
return None
if self._event_loop is None:
raise DataPlaneError("Not connected")
return asyncio.run_coroutine_threadsafe(
self._release_remote_tensor(remote_tensor), self._event_loop
).result()
async def _create_endpoint(self, host: str, port: int):
endpoint = self._endpoints.get((host, port))
if endpoint is None:
LOGGER.debug(f"Creating endpoint for {host}:{port}")
endpoint = await ucp.create_endpoint(host, port)
self._endpoints[(host, port)] = endpoint
else:
LOGGER.debug(f"Reusing endpoint for {host}:{port}")
return endpoint
async def _close_endpoint(self, host: str, port: int):
endpoint = self._endpoints.pop((host, port), None)
if endpoint is not None:
LOGGER.debug(f"Closing endpoint for {host}:{port}")
await endpoint.close()
else:
LOGGER.debug(f"Endpoint for {host}:{port} not found")
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import uuid
from multiprocessing import Process, Queue
from typing import Sequence
import cupy
import numpy
import pytest
import ucp
from cupy_backends.cuda.api.runtime import CUDARuntimeError
from triton_distributed.icp.data_plane import DataPlaneError
from triton_distributed.icp.ucp_data_plane import (
UcpDataPlane,
get_icp_tensor_uri,
set_icp_tensor_uri,
)
from tritonserver import DataType, MemoryType, Tensor
from tritonserver._api._datautils import TRITON_TO_NUMPY_DTYPE
# TODO decide if some tests should be removed
# from pre_merge
pytestmark = pytest.mark.pre_merge
def _cuda_available():
# Note: cupy.cuda.is_available() initializes CUDA
# and can't be called when forking subprocesses
try:
return cupy.cuda.is_available()
except CUDARuntimeError:
return False
def data_plane_reader(
input_tensor_queue: Queue,
tensor_descriptor_queue: Queue,
output_tensor_queue: Queue,
memory_type: MemoryType,
memory_type_id: int,
):
ucp.reset()
data_plane = UcpDataPlane()
data_plane.connect()
output_tensor = None
get_error = None
release_error = None
while True:
input_tensor = tensor_descriptor_queue.get()
if input_tensor is None:
break
try:
output_tensor = data_plane.get_tensor(
input_tensor,
requested_memory_type=memory_type,
requested_memory_type_id=memory_type_id,
)
if output_tensor.data_type == DataType.BYTES:
output_tensor = output_tensor.to_bytes_array()
else:
if memory_type == MemoryType.GPU and _cuda_available():
output_tensor = cupy.from_dlpack(output_tensor)
else:
output_tensor = numpy.from_dlpack(output_tensor)
except DataPlaneError as e:
get_error = e
try:
data_plane.release_tensor(input_tensor)
except DataPlaneError as e:
release_error = e
if get_error:
output_tensor_queue.put((get_error, release_error))
else:
output_tensor_queue.put((output_tensor, release_error))
output_tensor_queue.put((None, None))
data_plane.close()
def data_plane_writer(
input_tensor_queue: Queue,
tensor_descriptor_queue: Queue,
output_tensor_queue: Queue,
memory_type: MemoryType,
memory_type_id: int,
use_invalid_descriptor: bool = False,
timeout=30,
use_tensor_contents: bool = False,
):
ucp.reset()
data_plane = UcpDataPlane()
data_plane.connect()
while True:
input_tensor = input_tensor_queue.get()
if input_tensor is None:
tensor_descriptor_queue.put(None)
break
input_tensor = Tensor._from_object(input_tensor)
input_tensor_descriptor = data_plane.put_input_tensor(
input_tensor, use_tensor_contents=use_tensor_contents
)
if use_invalid_descriptor and not use_tensor_contents:
tensor_uri = get_icp_tensor_uri(input_tensor_descriptor)
invalid_tensor_id = str(uuid.uuid1())
tensor_uri = tensor_uri[: -len(invalid_tensor_id)]
tensor_uri = tensor_uri + invalid_tensor_id
set_icp_tensor_uri(input_tensor_descriptor, tensor_uri)
tensor_descriptor_queue.put(input_tensor_descriptor)
if not use_invalid_descriptor and timeout:
data_plane.close(wait_for_release=timeout)
data_plane.close()
@pytest.fixture
def tensors():
tensors = [
numpy.random.randint(0, 10, size=(2, 3)),
numpy.random.randint(0, 10, size=(100)),
numpy.random.randint(2, size=(1), dtype=bool),
]
return tensors
@pytest.mark.timeout(60, method="thread")
def test_data_plane_error_invalid_tensor_uri(request):
input_tensor_queue: Queue = Queue()
tensor_descriptor_queue: Queue = Queue()
output_tensor_queue: Queue = Queue()
input_tensors = []
memory_type = MemoryType.CPU
memory_type_id = 0
reader = Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
True,
30,
),
)
reader.start()
writer.start()
tensors = request.getfixturevalue("tensors")
for tensor in tensors:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.memory_type == MemoryType.CPU or not _cuda_available():
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(cupy.from_dlpack(input_tensor))
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert isinstance(output_tensor, DataPlaneError)
assert isinstance(release_error, DataPlaneError)
@pytest.mark.timeout(30, method="thread")
@pytest.mark.parametrize(
"memory_type,memory_type_id", [(MemoryType.CPU, 0), (MemoryType.GPU, 0)]
)
def test_requested_memory_type(memory_type, memory_type_id, request):
ctx = multiprocessing.get_context("spawn")
input_tensor_queue = ctx.Queue()
tensor_descriptor_queue = ctx.Queue()
output_tensor_queue = ctx.Queue()
input_tensors = []
output_tensors = []
reader = ctx.Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = ctx.Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
False,
30,
),
)
reader.start()
writer.start()
tensors = request.getfixturevalue("tensors")
for tensor in tensors:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.memory_type == MemoryType.CPU or not _cuda_available():
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(cupy.from_dlpack(input_tensor))
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert not isinstance(output_tensor, DataPlaneError)
output_tensors.append(Tensor.from_dlpack(output_tensor))
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
expected_memory_type = memory_type
if not _cuda_available():
expected_memory_type = MemoryType.CPU
assert output_tensor.memory_type == expected_memory_type
assert output_tensor.memory_type_id == memory_type_id
input_comparison = numpy.from_dlpack(input_tensor.to_host())
output_comparison = numpy.from_dlpack(output_tensor.to_host())
numpy.testing.assert_equal(input_comparison, output_comparison)
print(input_tensor, output_tensor)
def _get_random_tensor(data_type: DataType, size: Sequence[int]):
dtype = TRITON_TO_NUMPY_DTYPE[data_type]
value = numpy.random.rand(*size)
return value.astype(dtype)
@pytest.mark.timeout(30, method="thread")
@pytest.mark.parametrize(
"data_type",
[
data_type
for data_type in DataType.__members__.values()
if data_type not in [DataType.INVALID, DataType.BF16]
],
ids=[
data_type
for data_type in DataType.__members__.keys()
if data_type not in ["INVALID", "BF16"]
],
)
def test_tensor_types(request, data_type):
input_tensor_queue: Queue = Queue()
tensor_descriptor_queue: Queue = Queue()
output_tensor_queue: Queue = Queue()
input_tensors = []
output_tensors = []
memory_type = MemoryType.CPU
memory_type_id = 0
reader = Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
False,
30,
),
)
reader.start()
writer.start()
tensors = []
tensors.append(_get_random_tensor(data_type, [1, 4]))
for tensor in tensors:
if data_type == DataType.BYTES:
input_tensors.append(Tensor._from_object(tensor))
else:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.data_type != DataType.BYTES:
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(input_tensor.to_bytes_array())
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert not isinstance(output_tensor, DataPlaneError)
output_tensors.append(Tensor._from_object(output_tensor))
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
expected_memory_type = memory_type
if not _cuda_available():
expected_memory_type = MemoryType.CPU
assert output_tensor.memory_type == expected_memory_type
assert output_tensor.memory_type_id == memory_type_id
if input_tensor.data_type == DataType.BYTES:
input_comparison = input_tensor.to_bytes_array()
output_comparison = output_tensor.to_bytes_array()
else:
input_comparison = numpy.from_dlpack(input_tensor.to_host())
output_comparison = numpy.from_dlpack(output_tensor.to_host())
numpy.testing.assert_equal(input_comparison, output_comparison)
print(input_tensor, output_tensor)
@pytest.mark.timeout(30, method="thread")
@pytest.mark.parametrize(
"data_type",
[
data_type
for data_type in DataType.__members__.values()
if data_type not in [DataType.INVALID, DataType.BF16]
],
ids=[
data_type
for data_type in DataType.__members__.keys()
if data_type not in ["INVALID", "BF16"]
],
)
def test_use_tensor_contents(request, data_type):
input_tensor_queue: Queue = Queue()
tensor_descriptor_queue: Queue = Queue()
output_tensor_queue: Queue = Queue()
input_tensors = []
output_tensors = []
memory_type = MemoryType.CPU
memory_type_id = 0
reader = Process(
target=data_plane_reader,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
),
)
writer = Process(
target=data_plane_writer,
args=(
input_tensor_queue,
tensor_descriptor_queue,
output_tensor_queue,
memory_type,
memory_type_id,
True,
30,
True,
),
)
reader.start()
writer.start()
tensors = []
tensors.append(_get_random_tensor(data_type, [2, 4]))
for tensor in tensors:
if data_type == DataType.BYTES:
input_tensors.append(Tensor._from_object(tensor))
else:
input_tensors.append(Tensor.from_dlpack(tensor))
for input_tensor in input_tensors:
if input_tensor.data_type != DataType.BYTES:
input_tensor_queue.put(numpy.from_dlpack(input_tensor))
else:
input_tensor_queue.put(input_tensor.to_bytes_array())
input_tensor_queue.put(None)
reader.join()
writer.join()
while True:
output_tensor, release_error = output_tensor_queue.get()
if output_tensor is None:
break
assert not isinstance(output_tensor, DataPlaneError)
output_tensors.append(Tensor._from_object(output_tensor))
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
expected_memory_type = memory_type
if not _cuda_available():
expected_memory_type = MemoryType.CPU
assert output_tensor.memory_type == expected_memory_type
assert output_tensor.memory_type_id == memory_type_id
if input_tensor.data_type == DataType.BYTES:
input_comparison = input_tensor.to_bytes_array()
output_comparison = output_tensor.to_bytes_array()
else:
input_comparison = numpy.from_dlpack(input_tensor.to_host())
output_comparison = numpy.from_dlpack(output_tensor.to_host())
numpy.testing.assert_equal(input_comparison, output_comparison)
print(input_tensor, output_tensor)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import shutil
import subprocess
import time
import uuid
from multiprocessing import Process, Queue
import pytest
from triton_distributed.icp.nats_request_plane import NatsRequestPlane
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.request_plane import (
get_icp_component_id,
set_icp_final_response,
)
NATS_PORT = 4222
# TODO decide if some tests should be removed
# from pre_merge
pytestmark = pytest.mark.pre_merge
def is_port_in_use(port: int) -> bool:
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
@pytest.fixture
def nats_server(request):
command = [
"/usr/local/bin/nats-server",
"--jetstream",
"--debug",
"--trace",
"--port",
str(NATS_PORT),
]
print(f"Running: [{' '.join(command)}]")
# Raise a more intuitive error for the developer if the port is already in use.
if is_port_in_use(NATS_PORT):
raise RuntimeError(
f"ERROR: NATS Port {NATS_PORT} already in use. Is a nats-server already running?"
)
shutil.rmtree("/tmp/nats", ignore_errors=True)
with open("nats_server.stdout.log", "wt") as output_:
with open("nats_server.stderr.log", "wt") as output_err:
process = subprocess.Popen(
command, stdin=subprocess.DEVNULL, stdout=output_, stderr=output_err
)
time.sleep(1)
yield process
process.terminate()
process.wait()
shutil.rmtree("/tmp/nats", ignore_errors=True)
class ResponseHandler:
def __init__(self, request_plane):
self._request_plane = request_plane
async def response_handler(self, response):
print(response)
request = ModelInferRequest()
request.model_name = response.model_name
request.model_version = response.model_version
print("publishing request")
acks = []
for i in range(5):
acks.append(
self._request_plane.post_request(
request, response_handler=self.response_handler
)
)
asyncio.gather(*acks)
@pytest.mark.timeout(30)
async def test_handler(nats_server):
model_name = str(uuid.uuid1())
model_version = "1"
client_request_plane = NatsRequestPlane()
await client_request_plane.connect()
request = ModelInferRequest()
request.model_name = model_name
request.model_version = model_version
await client_request_plane.post_request(
request, response_handler=ResponseHandler(client_request_plane).response_handler
)
worker_request_plane = NatsRequestPlane()
await worker_request_plane.connect()
request_count = 10
while request_count > 0:
requests = await worker_request_plane.pull_requests(
model_name, model_version, 100, 0.1
)
acks = []
async for request in requests:
request_count -= 1
response = ModelInferResponse()
set_icp_final_response(response, True)
acks.append(worker_request_plane.post_response(request, response))
print(request_count)
await asyncio.gather(*acks)
await asyncio.sleep(0.1)
requests = await worker_request_plane.pull_requests(
model_name, model_version, 100, 0.1
)
await worker_request_plane.close()
await client_request_plane.close()
def run_request_generator(request_queue, response_queue, direct_requests=False):
asyncio.run(
request_generator(
request_queue, response_queue, direct_requests=direct_requests
)
)
async def request_generator(request_queue, response_queue, direct_requests=False):
# Generate requests and wait for responses.
# If direct_requests == True, send all requests to the
# worker that responds to the first request.
request_plane = NatsRequestPlane()
await request_plane.connect()
target_component_id = None
while True:
request_bytes = request_queue.get()
if request_bytes is None:
response_queue.put(None)
break
request = ModelInferRequest()
request.ParseFromString(request_bytes)
async for response in await request_plane.post_request(
request, response_iterator=True, component_id=target_component_id
):
if direct_requests:
target_component_id = get_icp_component_id(response)
print(response)
response_queue.put(response.SerializeToString())
def run_worker(model_name, model_version, batch_size, request_count, pull_timeout=0.1):
asyncio.run(
worker(model_name, model_version, batch_size, request_count, pull_timeout)
)
async def worker(
model_name, model_version, batch_size, request_count, pull_timeout=0.1
):
request_plane = NatsRequestPlane()
await request_plane.connect()
while request_count:
requests = await request_plane.pull_requests(
model_name, model_version, batch_size, pull_timeout
)
acks = []
async for request in requests:
print(request)
request_count -= 1
response = ModelInferResponse()
set_icp_final_response(response, True)
acks.append(request_plane.post_response(request, responses=response))
await asyncio.gather(*acks)
@pytest.mark.timeout(30)
async def test_iterator(nats_server):
batch_size = 10
request_count = 100
model_name = str(uuid.uuid1())
model_version = "1"
request_queue: Queue = Queue()
response_queue: Queue = Queue()
generator_process = Process(
target=run_request_generator, args=(request_queue, response_queue)
)
worker_process = Process(
target=run_worker, args=(model_name, model_version, batch_size, request_count)
)
generator_process.start()
worker_process.start()
for index in range(request_count):
request_queue.put(
ModelInferRequest(
model_name=model_name, model_version=model_version, id=str(index)
).SerializeToString()
)
request_queue.put(None)
generator_process.join()
worker_process.join()
response_count = 0
while True:
response = response_queue.get()
if response is None:
break
response_count += 1
assert request_count == response_count
@pytest.mark.parametrize("pull_timeout,batch_size", [(0.1, 10), (None, 1)])
@pytest.mark.timeout(30)
async def test_direct_requests(nats_server, pull_timeout, batch_size):
request_count = 100
model_name = str(uuid.uuid1())
model_version = "1"
request_queue: Queue = Queue()
response_queue: Queue = Queue()
# Note with direct_requests == True
# all requests should target a single worker
# and all responses should be from a single worker
generator_process = Process(
target=run_request_generator,
args=(request_queue, response_queue),
kwargs={"direct_requests": True},
)
worker_process_1 = Process(
target=run_worker,
args=(model_name, model_version, batch_size, request_count, pull_timeout),
)
worker_process_2 = Process(
target=run_worker,
args=(model_name, model_version, batch_size, request_count, pull_timeout),
)
worker_process_1.start()
worker_process_2.start()
time.sleep(1)
generator_process.start()
for index in range(request_count):
request_queue.put(
ModelInferRequest(
model_name=model_name, model_version=model_version, id=str(index)
).SerializeToString()
)
request_queue.put(None)
generator_process.join()
worker_process_1.terminate()
worker_process_1.join()
worker_process_2.terminate()
worker_process_2.join()
response_count = 0
responders = set()
while True:
request_bytes = response_queue.get()
if request_bytes is None:
break
response = ModelInferResponse()
response.ParseFromString(request_bytes)
response_count += 1
responders.add(get_icp_component_id(response))
assert len(responders) == 1
assert request_count == response_count
@@ -46,7 +46,7 @@ skip = ["build"]
 [tool.pytest.ini_options]
 minversion = "8.0"
-addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
+addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "--mypy"]
 xfail_strict = true
 log_cli_level = "INFO"
 filterwarnings = [
@@ -54,6 +54,11 @@ filterwarnings = [
 ]
 # NOTE: Can also manually mark tests with @pytest.mark.asyncio
 asyncio_mode = "auto"
+markers = [
+    "pre_merge: marks tests to run before merging",
+    "nightly: marks tests to run nightly",
+    "weekly: marks tests to run weekly"
+]

 # Linting/formatting
 [tool.ruff]
@@ -64,7 +69,7 @@ indent-width = 4
 [tool.mypy]
 # --disable-error-code: WAR large set of errors due to mypy not being run
 # previously. We can slowly enable sets of errors to fix over time.
-disable_error_code = "arg-type,assignment,attr-defined,call-arg,call-overload,has-type,import-untyped,misc,operator,override,union-attr,var-annotated"
+#disable_error_code = "arg-type,assignment,attr-defined,call-arg,call-overload,has-type,import-untyped,misc,operator,override,union-attr,var-annotated"
 # --explicit-package-bases: WAR errors about duplicate module names used
 # throughout project such as launch_workers.py
 explicit_package_bases = true
@@ -72,4 +77,6 @@ explicit_package_bases = true
 # of container environment with PYTHONPATH set and packages installed.
 # NOTE: Can possibly move mypy from pre-commit to a github action run only in
 # a container with the expected environment and PYTHONPATH setup.
-ignore_missing_imports = true
\ No newline at end of file
+ignore_missing_imports = true
+check_untyped_defs = true
\ No newline at end of file
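For reference (not part of the diff), the markers declared above can be selected at run time with pytest's -m option; a minimal sketch invoking only the pre_merge subset from Python:

import pytest

# Run only tests marked pre_merge, mirroring the markers declared in pyproject.toml.
raise SystemExit(pytest.main(["-m", "pre_merge"]))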