Commit 0bfd9a76 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

refactor: remove python native runtime

parent 8f741f14
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
from dataclasses import field
from typing import Any, AsyncGenerator, List, Optional
import numpy as np
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.runtime import (
Operator,
RemoteInferenceRequest,
RemoteInferenceResponse,
RemoteOperator,
)
from .stages import AggregatedStage, GenerateStage, PrefillStage, Stage
class VllmOperator(Operator):
    """Operator that serves vLLM inference requests through a single aggregated stage.

    Requests arrive as ``RemoteInferenceRequest`` objects; their tensors are
    converted via DLPack and their parameters decoded, then streamed through
    the aggregated (prefill + generate) vLLM stage.
    """

    def __init__(
        self,
        name: str,
        version: int,
        request_plane: RequestPlane,
        data_plane: DataPlane,
        # Fixed: the original default was `field(default_factory=dict)`, which
        # is only meaningful inside a @dataclass; as a plain function default
        # it evaluates to a dataclasses.Field object at import time.  Use the
        # None-sentinel idiom instead.
        parameters: Optional[dict[str, str | int | bool | bytes]] = None,
        repository: Optional[str] = None,  # accepted for interface parity; unused here
        logger: Optional[logging.Logger] = None,
        triton_core: Optional[Any] = None,  # accepted for interface parity; unused here
    ):
        self.name = name
        self.version = version
        self.request_plane = request_plane
        self.data_plane = data_plane
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self._stage: Stage
        self._init_stages(parameters if parameters is not None else {})

    async def execute(self, requests: List[RemoteInferenceRequest]) -> None:
        """Run each request through the stage and stream responses back.

        Any exception while processing a request is reported to the client as
        a final error response instead of propagating.
        """
        for request in requests:
            response_sender = request.response_sender()
            try:
                inputs, parameters = self._prepare_inputs(request)
                self.logger.debug("Processing request")
                async for response in self._stage(
                    {
                        "inputs": inputs,
                        "parameters": parameters,
                    }
                ):
                    self.logger.debug("Sending response")
                    await response_sender.send(**response)
                    self.logger.debug("Response send")
            except Exception as e:
                self.logger.error(f"Error processing request: {e}")
                await response_sender.send(error=e, final=True)

    def _init_stages(
        self,
        parameters: Optional[dict[str, str | int | bool | bytes]] = None,
    ):
        """Build the aggregated vLLM stage from CLI-style parameter dict.

        The parameter keys are treated as argparse-style attribute names
        (model_name, baseline_tp_size, ...); missing keys will raise
        AttributeError when accessed.
        """
        # Fixed: same `field(default_factory=dict)` misuse as in __init__.
        args = argparse.Namespace(**(parameters or {}))  # type: ignore
        self._stage = AggregatedStage(
            model=args.model_name,
            tensor_parallel_size=args.baseline_tp_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_model_len=args.max_model_len,
            dtype=args.dtype,
            kv_cache_dtype=args.kv_cache_dtype,
            enable_prefix_caching=args.enable_prefix_caching,
            enable_chunked_prefill=args.enable_chunked_prefill,
            enforce_eager=args.enforce_eager,
            ignore_eos=args.ignore_eos,
            max_num_seqs=args.max_num_seqs,
            disable_async_output_proc=args.disable_async_output_proc,
            disable_log_stats=args.disable_log_stats,
        )

    @staticmethod
    def _prepare_inputs(request: RemoteInferenceRequest):
        """Convert request tensors via DLPack; decode "JSON:"-prefixed parameters.

        Returns a (inputs, parameters) tuple of dicts.
        """
        inputs, parameters = {}, {}
        for input_name, input_data in request.inputs.items():
            inputs[input_name] = np.from_dlpack(input_data)
        for key, value in request.parameters.items():
            if isinstance(value, str) and value.startswith("JSON:"):
                # Strip the 5-char "JSON:" prefix and parse the remainder.
                parameters[key] = json.loads(value[5:])
            else:
                parameters[key] = value
        return inputs, parameters
class VllmContextOperator(VllmOperator):
    """Disaggregated context operator: runs prefill locally, then forwards the
    prefill outputs to a remote "generate" operator and relays its responses."""

    def _init_stages(
        self,
        # Fixed: `field(default_factory=dict)` is only valid inside a
        # @dataclass; as a plain default it yields a dataclasses.Field object.
        parameters: Optional[dict[str, str | int | bool | bytes]] = None,
    ):
        """Build the prefill stage and a handle to the remote generate operator."""
        args = argparse.Namespace(**(parameters or {}))  # type: ignore
        self._prefill_stage = PrefillStage(
            model=args.model_name,
            tensor_parallel_size=args.context_tp_size,
            generate_tensor_parallel_size=args.generate_tp_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_model_len=args.max_model_len,
            dtype=args.dtype,
            kv_cache_dtype=args.kv_cache_dtype,
            enable_prefix_caching=args.enable_prefix_caching,
            enable_chunked_prefill=args.enable_chunked_prefill,
            enforce_eager=args.enforce_eager,
            ignore_eos=args.ignore_eos,
            max_num_seqs=args.max_num_seqs,
            disable_async_output_proc=args.disable_async_output_proc,
            disable_log_stats=args.disable_log_stats,
        )
        self._generate_operator = RemoteOperator(
            "generate", self.request_plane, self.data_plane
        )

    async def execute(self, requests: List[RemoteInferenceRequest]) -> None:
        """Prefill each request locally, then stream the remote generate
        operator's responses back to the caller.

        Errors are reported to the client as a final error response.
        """
        for request in requests:
            response_sender = request.response_sender()
            try:
                self.logger.info("Processing request")
                inputs, parameters = self._prepare_inputs(request)
                responses = [
                    response
                    async for response in self._prefill_stage(
                        {
                            "inputs": inputs,
                            "parameters": parameters,
                        }
                    )
                ]
                self.logger.info("Prefill finished")
                # Contract: prefill yields exactly one response whose outputs
                # seed the remote generate call below.
                assert len(responses) == 1
                response = responses[0]
                self.logger.info("Processing generate")
                generate_responses: AsyncGenerator[
                    RemoteInferenceResponse, None
                ] = await self._generate_operator.async_infer(
                    inputs=response["outputs"],
                    # Request parameters are forwarded; prefill-produced
                    # parameters take precedence on key collisions.
                    parameters={**request.parameters, **response["parameters"]},
                )
                async for generate_response in generate_responses:
                    self.logger.info("Sending response")
                    # Only the "text" parameter is relayed to the client.
                    parameters = {"text": generate_response.parameters["text"]}
                    await response_sender.send(
                        outputs=generate_response.outputs,
                        parameters=parameters,
                        final=generate_response.final,
                        error=generate_response.error,
                    )
                self.logger.info("Response send")
            except Exception as e:
                self.logger.error(f"Error processing request: {e}")
                await response_sender.send(error=e, final=True)
class VllmGenerateOperator(VllmOperator):
    """Disaggregated generate operator: runs only the decode/generate stage for
    requests whose prefill happened elsewhere."""

    def _init_stages(
        self,
        # Fixed: `field(default_factory=dict)` is only valid inside a
        # @dataclass; as a plain default it yields a dataclasses.Field object.
        parameters: Optional[dict[str, str | int | bool | bytes]] = None,
    ):
        """Build the generate stage from CLI-style parameters."""
        args = argparse.Namespace(**(parameters or {}))  # type: ignore
        # The stage registers itself as the "generate" worker.
        args.worker_name = "generate"
        self._stage = GenerateStage(
            model=args.model_name,
            tensor_parallel_size=args.generate_tp_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_model_len=args.max_model_len,
            dtype=args.dtype,
            kv_cache_dtype=args.kv_cache_dtype,
            enable_prefix_caching=args.enable_prefix_caching,
            enable_chunked_prefill=args.enable_chunked_prefill,
            enforce_eager=args.enforce_eager,
            ignore_eos=args.ignore_eos,
            max_num_seqs=args.max_num_seqs,
            disable_async_output_proc=args.disable_async_output_proc,
            disable_log_stats=args.disable_log_stats,
        )
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Event Plane example
A basic example that demonstrates how to use the Event Plane API to create an event plane, register an event, and trigger the event.
## Code overview
### Using context managers (recommended)
```python
async def example_with_context_managers():
server_url = "tls://localhost:4222"
component_id = uuid.uuid4()
async with NatsEventPlane(server_url, component_id) as plane:
received_events = []
async def callback(event):
print(event)
received_events.append(event)
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"my_payload"
# Subscribe using context manager
async with await plane.subscribe(callback, event_topic=event_topic, event_type=event_type):
# Publish event
await plane.publish(event, event_type, event_topic)
# Allow time for message to propagate
await asyncio.sleep(2)
```
### Manual resource management
#### 1) Initialize NATS server and create an event plane
```python
server_url = "tls://localhost:4222" # Optional, default is nats://localhost:4222
component_id = uuid.uuid4() # Optional, component_id will be generated if not given
plane = NatsEventPlane(server_url, component_id)
await plane.connect()
```
#### 2) Define the callback function for receiving events
```python
received_events = []
async def callback(event):
print(event)
received_events.append(event)
```
#### 3) Prepare the event topic, event type, and event payload
```python
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"my_payload"
```
#### 4) Subscribe to the event topic and type, and register the callback function
```python
subscription = await plane.subscribe(callback, event_topic=event_topic, event_type=event_type)
```
#### 5) Publish the event
```python
await plane.publish(event, event_type, event_topic)
```
#### 6) Clean up resources
```python
# Unsubscribe when done
await subscription.unsubscribe()
# Disconnect from NATS server
await plane.disconnect()
```
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from triton_distributed.icp.nats_event_plane import DEFAULT_EVENTS_PORT
def parse_args(args=None):
    """Parse command-line options for the event plane example.

    When *args* is None, argparse falls back to ``sys.argv[1:]``.
    """
    cli = argparse.ArgumentParser(description="Event Plane Example")
    cli.add_argument(
        "--nats-port",
        type=int,
        default=DEFAULT_EVENTS_PORT,
        help="Nats server port",
    )
    cli.add_argument(
        "--publisher-count",
        type=int,
        default=1,
        help="Number of publishers to deploy",
    )
    cli.add_argument(
        "--subscriber-count",
        type=int,
        default=10,
        help="Number of subscribers to deploy",
    )
    return cli.parse_args(args)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from triton_distributed.icp.nats_event_plane import (
EventTopic,
NatsEventPlane,
compose_nats_url,
)
async def single_publisher_subscriber_example():
    """Publish one event and receive it through a subscription on the same plane.

    Connects to the NATS event plane, registers a callback that records
    delivered events, publishes a single byte payload, waits for delivery,
    and disconnects.  (Removed dead commented-out code from an earlier
    context-manager based draft.)
    """
    server_url = compose_nats_url()
    component_id = str(uuid.uuid4())
    plane = NatsEventPlane(server_url, component_id)
    await plane.connect()

    received_events = []

    async def callback(event):
        # Print the event wrapper, its raw payload, and its typed payload.
        print(event)
        print(event.payload)
        print(event.typed_payload(bytes))
        received_events.append(event)

    event_topic = EventTopic(["test", "event_topic"])
    event_type = "test_event"
    event = b"my_payload"

    await plane.subscribe(callback, event_topic=event_topic, event_type=event_type)
    await plane.publish(event, event_type, event_topic)

    # Allow time for message to propagate
    await asyncio.sleep(3)
    print(f"received_events: {received_events}")
    await plane.disconnect()


if __name__ == "__main__":
    asyncio.run(single_publisher_subscriber_example())
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uuid
from triton_distributed.icp.nats_event_plane import (
EventTopic,
NatsEventPlane,
compose_nats_url,
)
async def main(args):
    """Publish a single event described by the parsed CLI arguments."""
    event_plane = NatsEventPlane(compose_nats_url(), args.component_id)
    await event_plane.connect()
    try:
        # A dot-separated topic string becomes a multi-level EventTopic.
        if args.event_topic:
            event_topic = EventTopic(args.event_topic.split("."))
        else:
            event_topic = None
        payload = args.payload.encode()
        await event_plane.publish(payload, args.event_type, event_topic)
        print(f"Published event from publisher {args.event_topic}")
    finally:
        # Always release the NATS connection, even if publishing fails.
        await event_plane.disconnect()
if __name__ == "__main__":
    # CLI entry point: parse options and publish one event.
    parser = argparse.ArgumentParser(description="Event publisher script")
    parser.add_argument(
        "--component-id",
        type=uuid.UUID,
        # Evaluated once at startup; a fresh random UUID per process.
        default=uuid.uuid4(),
        help="Component ID (UUID)",
    )
    parser.add_argument(
        "--event-topic",
        type=str,
        default=None,
        # Fixed: main() splits the topic on ".", so levels are dot-separated
        # (the help text previously said comma-separated).
        help="Event EventTopic to subscribe to (dot-separated for multiple levels)",
    )
    parser.add_argument(
        "--event-type", type=str, default="test_event", help="Event type"
    )
    parser.add_argument(
        "--payload",
        type=str,
        default="test_payload",
        help="Payload to be published with event.",
    )
    args = parser.parse_args()
    asyncio.run(main(args))
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uuid
from triton_distributed.icp.nats_event_plane import (
EventTopic,
NatsEventPlane,
compose_nats_url,
)
async def main(args):
    """Connect to NATS, subscribe with the configured filters, and print
    received events until interrupted."""
    server_url = compose_nats_url()
    event_plane = NatsEventPlane(server_url, uuid.uuid4())

    # Fixed: `args.subscriber_id` is not defined by the argument parser in
    # this script, so accessing it raised AttributeError.  Guard the lookup
    # and fall back to the component id so log lines stay identifiable.
    subscriber_id = getattr(args, "subscriber_id", None) or args.component_id

    async def callback(received_event):
        # Dump every field of the delivered event for inspection.
        print(
            f"""
            Subscriber {subscriber_id}
            received event: {received_event.event_id}
            event payload: {received_event.payload.tobytes().decode("utf-8")}
            event.topic: {received_event.event_topic}
            event.type: {received_event.event_type}
            event.component_id: {received_event.component_id}
            event.timestamp: {received_event.timestamp}
            """
        )

    await event_plane.connect()
    try:
        # A dot-separated topic string becomes a multi-level EventTopic.
        event_topic = (
            EventTopic(args.event_topic.split(".")) if args.event_topic else None
        )
        print(f"Subscribing to event_topic: {args.event_topic}")
        await event_plane.subscribe(
            callback,
            event_topic=event_topic,
            event_type=args.event_type,
            component_id=args.component_id,
        )
        print(
            f"Subscriber {subscriber_id} is listening on event_topic {event_topic} with event type '{args.event_type or 'all'}' "
            + f"component ID '{args.component_id}'"
        )
        while True:
            await asyncio.sleep(5)  # Keep the subscriber running
            print(f"Subscriber {subscriber_id} is still running")
    finally:
        # Always release the NATS connection, even on Ctrl-C.
        await event_plane.disconnect()
if __name__ == "__main__":
    # CLI entry point: parse filter options and run the subscriber loop.
    parser = argparse.ArgumentParser(description="Event subscriber script")
    parser.add_argument(
        "--event-topic",
        type=str,
        default=None,
        # Fixed: main() splits the topic on ".", so levels are dot-separated
        # (the help text previously said comma-separated).
        help="Event EventTopic to subscribe to (dot-separated for multiple levels)",
    )
    parser.add_argument(
        "--event-type",
        type=str,
        default=None,
        help="Event type to filter (default: None for all types)",
    )
    parser.add_argument(
        "--component-id",
        type=uuid.UUID,
        default=None,
        help="Component ID (UUID) for the subscriber",
    )
    # Fixed: main() prints `args.subscriber_id` in its log lines, but no such
    # argument was defined, causing an AttributeError at runtime.
    parser.add_argument(
        "--subscriber-id",
        type=str,
        default=str(uuid.uuid4()),
        help="Identifier used in log output for this subscriber",
    )
    args = parser.parse_args()
    asyncio.run(main(args))
#! /bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generate Python gRPC/protobuf stubs for icp.proto into the icp package.
# Fixed: all expansions are now quoted so the script survives paths
# containing spaces.
PROTO_SRC="$(dirname "$(realpath "$0")")"
SOURCE_ROOT="$(realpath "${PROTO_SRC}/..")"
PROTO_OUT="${SOURCE_ROOT}/python/src/triton_distributed/icp/protos"
mkdir -p "${PROTO_OUT}"
python3 -m grpc_tools.protoc -I"${PROTO_SRC}" --python_out="${PROTO_OUT}" --pyi_out="${PROTO_OUT}" icp.proto \
    && ls "${PROTO_OUT}"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package triton.distributed.icp;
//@@
//@@.. cpp:var:: message ModelInferRequest
//@@
//@@ Request message for ModelInfer.
//@@
//@@ NOTE(review): these messages appear to mirror the KServe v2 / Triton GRPC
//@@ inference protocol -- confirm field numbers stay in sync with upstream
//@@ before changing anything; numbers are wire-format contracts.
message ModelInferRequest
{
  //@@
  //@@  .. cpp:var:: message InferInputTensor
  //@@
  //@@     An input tensor for an inference request.
  //@@
  message InferInputTensor
  {
    //@@
    //@@    .. cpp:var:: string name
    //@@
    //@@       The tensor name.
    //@@
    string name = 1;

    //@@
    //@@    .. cpp:var:: string datatype
    //@@
    //@@       The tensor data type.
    //@@
    string datatype = 2;

    //@@
    //@@    .. cpp:var:: int64 shape (repeated)
    //@@
    //@@       The tensor shape.
    //@@
    repeated int64 shape = 3;

    //@@    .. cpp:var:: map<string,InferParameter> parameters
    //@@
    //@@       Optional inference input tensor parameters.
    //@@
    map<string, InferParameter> parameters = 4;

    //@@    .. cpp:var:: InferTensorContents contents
    //@@
    //@@       The tensor contents using a data-type format. This field
    //@@       must not be specified if tensor contents are being specified
    //@@       in ModelInferRequest.raw_input_contents.
    //@@
    InferTensorContents contents = 5;
  }

  //@@
  //@@  .. cpp:var:: message InferRequestedOutputTensor
  //@@
  //@@     An output tensor requested for an inference request.
  //@@
  message InferRequestedOutputTensor
  {
    //@@
    //@@    .. cpp:var:: string name
    //@@
    //@@       The tensor name.
    //@@
    string name = 1;

    //@@    .. cpp:var:: map<string,InferParameter> parameters
    //@@
    //@@       Optional requested output tensor parameters.
    //@@
    map<string, InferParameter> parameters = 2;
  }

  //@@  .. cpp:var:: string model_name
  //@@
  //@@     The name of the model to use for inferencing.
  //@@
  string model_name = 1;

  //@@  .. cpp:var:: string model_version
  //@@
  //@@     The version of the model to use for inference. If not
  //@@     given the latest/most-recent version of the model is used.
  //@@
  string model_version = 2;

  //@@  .. cpp:var:: string id
  //@@
  //@@     Optional identifier for the request. If specified will be
  //@@     returned in the response.
  //@@
  string id = 3;

  //@@  .. cpp:var:: map<string,InferParameter> parameters
  //@@
  //@@     Optional inference parameters.
  //@@
  map<string, InferParameter> parameters = 4;

  //@@
  //@@  .. cpp:var:: InferInputTensor inputs (repeated)
  //@@
  //@@     The input tensors for the inference.
  //@@
  repeated InferInputTensor inputs = 5;

  //@@
  //@@  .. cpp:var:: InferRequestedOutputTensor outputs (repeated)
  //@@
  //@@     The requested output tensors for the inference. Optional, if not
  //@@     specified all outputs specified in the model config will be
  //@@     returned.
  //@@
  repeated InferRequestedOutputTensor outputs = 6;
}
//@@
//@@.. cpp:var:: message ModelInferResponse
//@@
//@@ Response message for ModelInfer.
//@@
//@@ NOTE(review): field numbers mirror the KServe v2 / Triton GRPC protocol;
//@@ do not renumber -- they are wire-format contracts.
message ModelInferResponse
{
  //@@
  //@@  .. cpp:var:: message InferOutputTensor
  //@@
  //@@     An output tensor returned for an inference request.
  //@@
  message InferOutputTensor
  {
    //@@
    //@@    .. cpp:var:: string name
    //@@
    //@@       The tensor name.
    //@@
    string name = 1;

    //@@
    //@@    .. cpp:var:: string datatype
    //@@
    //@@       The tensor data type.
    //@@
    string datatype = 2;

    //@@
    //@@    .. cpp:var:: int64 shape (repeated)
    //@@
    //@@       The tensor shape.
    //@@
    repeated int64 shape = 3;

    //@@    .. cpp:var:: map<string,InferParameter> parameters
    //@@
    //@@       Optional output tensor parameters.
    //@@
    map<string, InferParameter> parameters = 4;

    //@@    .. cpp:var:: InferTensorContents contents
    //@@
    //@@       The tensor contents using a data-type format. This field
    //@@       must not be specified if tensor contents are being specified
    //@@       in ModelInferResponse.raw_output_contents.
    //@@
    InferTensorContents contents = 5;
  }

  //@@  .. cpp:var:: string model_name
  //@@
  //@@     The name of the model used for inference.
  //@@
  string model_name = 1;

  //@@  .. cpp:var:: string model_version
  //@@
  //@@     The version of the model used for inference.
  //@@
  string model_version = 2;

  //@@  .. cpp:var:: string id
  //@@
  //@@     The id of the inference request if one was specified.
  //@@
  string id = 3;

  //@@  .. cpp:var:: map<string,InferParameter> parameters
  //@@
  //@@     Optional inference response parameters.
  //@@
  map<string, InferParameter> parameters = 4;

  //@@
  //@@  .. cpp:var:: InferOutputTensor outputs (repeated)
  //@@
  //@@     The output tensors holding inference results.
  //@@
  repeated InferOutputTensor outputs = 5;
}
//@@
//@@.. cpp:var:: message InferParameter
//@@
//@@ An inference parameter value.
//@@
message InferParameter
{
  //@@  .. cpp:var:: oneof parameter_choice
  //@@
  //@@     The parameter value can be a string, an int64,
  //@@     an uint64, a double, or a boolean
  //@@
  //@@     Note: double and uint64 are currently
  //@@           placeholders for future use and
  //@@           are not supported for custom parameters
  //@@
  oneof parameter_choice
  {
    //@@    .. cpp:var:: bool bool_param
    //@@
    //@@       A boolean parameter value.
    //@@
    bool bool_param = 1;

    //@@    .. cpp:var:: int64 int64_param
    //@@
    //@@       An int64 parameter value.
    //@@
    int64 int64_param = 2;

    //@@    .. cpp:var:: string string_param
    //@@
    //@@       A string parameter value.
    //@@
    string string_param = 3;

    //@@    .. cpp:var:: double double_param
    //@@
    //@@       A double parameter value.
    //@@
    double double_param = 4;

    //@@    .. cpp:var:: uint64 uint64_param
    //@@
    //@@       A uint64 parameter value.
    //@@
    //@@       Not supported for custom parameters
    //@@
    uint64 uint64_param = 5;
  }
}
//@@
//@@.. cpp:var:: message InferTensorContents
//@@
//@@ The data contained in a tensor represented by the repeated type
//@@ that matches the tensor's data type. Protobuf oneof is not used
//@@ because oneofs cannot contain repeated fields.
//@@
message InferTensorContents
{
  //@@
  //@@  .. cpp:var:: bytes bytes_contents (repeated)
  //@@
  //@@     The size must match what is expected by the tensor's shape.
  //@@     The contents must be the flattened, one-dimensional,
  //@@     row-major order of the tensor elements.
  //@@
  //@@  NOTE(review): field number 8 (not 1) presumably matches the upstream
  //@@  KServe proto, where numbers 1-7 are the typed contents fields omitted
  //@@  here -- keep 8 for wire compatibility; verify against upstream.
  repeated bytes bytes_contents = 8;
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[build-system]
requires = ["setuptools>=65.0", "setuptools-scm>=8"]
build-backend = "setuptools.build_meta"
[project]
name = "triton-distributed-icp"
dynamic = ["version"]
authors = [
{ name = "NVIDIA Inc.", email = "sw-dl-triton@nvidia.com" },
]
license = { text = "Apache-2.0" }
# Minimum dependencies to import triton_distributed.icp
# TODO: Expand this list to include all dependencies and remove duplicates from container/deps.
dependencies = [
"cupy-cuda12x",
"nats-py",
"msgspec",
"ucx-py-cu12",
"protobuf==5.27.3",
"grpcio-tools==1.66.0",
]
[tool.setuptools_scm]
version_file = "src/triton_distributed/icp/_version.py"
root = "../.."
[tool.setuptools.packages.find]
where = ["src"]
include = ["triton_distributed.icp*"]
namespaces = true
[tool.setuptools]
license-files = ["../../LICENSE"]
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from triton_distributed.icp.data_plane import DataPlane as DataPlane
from triton_distributed.icp.event_plane import Event as Event
from triton_distributed.icp.event_plane import EventPlane as EventPlane
from triton_distributed.icp.event_plane import EventTopic as EventTopic
from triton_distributed.icp.nats_event_plane import (
DEFAULT_EVENTS_HOST as DEFAULT_EVENTS_HOST,
)
from triton_distributed.icp.nats_event_plane import (
DEFAULT_EVENTS_PORT as DEFAULT_EVENTS_PORT,
)
from triton_distributed.icp.nats_event_plane import NatsEventPlane as NatsEventPlane
from triton_distributed.icp.nats_request_plane import (
NatsRequestPlane as NatsRequestPlane,
)
from triton_distributed.icp.nats_request_plane import NatsServer as NatsServer
from triton_distributed.icp.request_plane import RequestPlane as RequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane as UcpDataPlane
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Type
class CustomKeyErrorDict(dict):
    """A dict whose missing-key errors describe a failed conversion.

    On a failed ``[]`` lookup, raises *exception* (default ``ValueError``)
    with a message naming the source (*from_name*) and target (*to_name*)
    of the conversion, instead of a bare ``KeyError``.  Other dict methods
    (``get``, ``in``, ...) behave as usual.
    """

    def __init__(
        self,
        from_name: str,
        to_name: str,
        *args,
        exception: Type[Exception] = ValueError,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._from_name = from_name
        self._to_name = to_name
        self._exception = exception

    def __getitem__(self, key):
        try:
            return super().__getitem__(key)
        except KeyError:
            # Suppress the KeyError context so only the descriptive error shows.
            message = (
                f"Unsupported {self._from_name}. "
                f"Can't convert {key} to {self._to_name}"
            )
            raise self._exception(message) from None
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
# This file contains the DLPack API wrapped in Python style (see
# 'dlpack.h' for detail) and the utilities for Triton client to interact
# with DLPack
#
# Ref:
# https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
# https://github.com/dmlc/dlpack/blob/main/apps/numpy_dlpack/dlpack/from_numpy.py
################################################################################
import ctypes
from typing import Union
from triton_distributed.icp._custom_key_error_dict import CustomKeyErrorDict
from triton_distributed.icp.data_type import DataType
from triton_distributed.icp.memory_type import MemoryType, string_to_memory_type
try:
import cupy
except ImportError:
cupy = None
# Need to explicitly set the result / argument types for pythonapi functions
# to work properly: ctypes defaults to c_int results, which truncates
# pointers on 64-bit platforms.
ctypes.pythonapi.PyMem_RawMalloc.restype = ctypes.c_void_p
ctypes.pythonapi.PyMem_RawFree.argtypes = [ctypes.c_void_p]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_New.argtypes = [
    ctypes.c_void_p,
    ctypes.c_char_p,
    ctypes.c_void_p,
]
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

# Capsule name used for DLPack tensor capsules (see dlpack.h reference above).
c_str_dltensor = b"dltensor"
class DLDeviceType(ctypes.c_int):
    """Device type codes mirroring the ``DLDeviceType`` enum in dlpack.h
    (see the Ref links in the module header)."""

    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3  # CUDA pinned host memory
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10
    kDLROCMHost = 11
    kDLExtDev = 12
    kDLCUDAManaged = 13
    kDLOneAPI = 14
    kDLWebGPU = 15
    kDLHexagon = 16
# Accepted ways to name a device/memory target: a (MemoryType, device-id)
# pair, a bare MemoryType, a DLPack (DLDeviceType, device-id) pair, or a
# string form (presumably parsed via string_to_memory_type -- see import above).
DeviceOrMemoryType = Union[
    tuple[MemoryType, int], MemoryType, tuple[DLDeviceType, int], str
]
class DLDevice(ctypes.Structure):
    """ctypes mirror of the ``DLDevice`` struct in dlpack.h."""

    _fields_ = [
        ("device_type", ctypes.c_int),
        ("device_id", ctypes.c_int),
    ]
class DLDataTypeCode(ctypes.c_uint8):
    """Type-code values mirroring the ``DLDataTypeCode`` enum in dlpack.h."""

    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaquePointer = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6
class DLDataType(ctypes.Structure):
    """ctypes mirror of the ``DLDataType`` struct in dlpack.h:
    (type_code, bits per lane, number of lanes)."""

    _fields_ = [
        ("type_code", ctypes.c_uint8),
        ("bits", ctypes.c_uint8),
        ("lanes", ctypes.c_uint16),
    ]
class DLTensor(ctypes.Structure):
    """ctypes mirror of the ``DLTensor`` struct in dlpack.h."""

    _fields_ = [
        ("data", ctypes.c_void_p),          # pointer to the tensor payload
        ("device", DLDevice),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        ("strides", ctypes.POINTER(ctypes.c_int64)),  # NULL => compact row-major
        ("byte_offset", ctypes.c_uint64),
    ]
class DLManagedTensor(ctypes.Structure):
    """ctypes mirror of the ``DLManagedTensor`` struct in dlpack.h: a DLTensor
    plus an opaque manager context and the deleter the consumer must call."""

    _fields_ = [
        ("dl_tensor", DLTensor),
        ("manager_ctx", ctypes.c_void_p),
        ("deleter", ctypes.CFUNCTYPE(None, ctypes.c_void_p)),
    ]
# Utilities
def _raise_error(msg):
"""
Raise error with the provided message
"""
raise Exception(msg) from None
# Use as managed context in DLPack that doesn't hold ownership of the
# data content.
class DataViewContext:
    """Manager context for a non-owning DLPack view: keeps the ctypes
    shape/stride buffers alive for the lifetime of the exported tensor."""

    def __init__(self, shape) -> None:
        # Convert the Python object to ctypes objects expected by
        # DLPack
        self._shape = (ctypes.c_int64 * len(shape))(*shape)
        # No strides: compact and row-major
        self._strides = ctypes.POINTER(ctypes.c_int64)()

    def as_manager_ctx(self) -> ctypes.c_void_p:
        # Return a void* to a pointer-to-self.  Both the context object and
        # the pointer wrapper are manually incref'd here so they outlive this
        # call; managed_tensor_deleter (below) performs the matching decrefs
        # when the consumer releases the tensor.
        py_obj = ctypes.py_object(self)
        py_obj_ptr = ctypes.pointer(py_obj)
        ctypes.pythonapi.Py_IncRef(py_obj)
        ctypes.pythonapi.Py_IncRef(ctypes.py_object(py_obj_ptr))
        return ctypes.cast(py_obj_ptr, ctypes.c_void_p)
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def managed_tensor_deleter(handle: ctypes.c_void_p) -> None:
    # DLManagedTensor.deleter callback: releases the two references taken
    # in DataViewContext.as_manager_ctx, then frees the raw-allocated
    # DLManagedTensor struct itself.
    dl_managed_tensor = DLManagedTensor.from_address(handle)  # type: ignore
    py_obj_ptr = ctypes.cast(
        dl_managed_tensor.manager_ctx, ctypes.POINTER(ctypes.py_object)
    )
    py_obj = py_obj_ptr.contents
    ctypes.pythonapi.Py_DecRef(py_obj)
    ctypes.pythonapi.Py_DecRef(ctypes.py_object(py_obj_ptr))
    ctypes.pythonapi.PyMem_RawFree(handle)
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pycapsule_deleter(handle: ctypes.c_void_p) -> None:
    # Capsule destructor: if the capsule is still named "dltensor" the
    # consumer never took ownership (a consumed capsule is renamed per
    # the DLPack protocol), so we must run the managed-tensor deleter
    # ourselves to avoid leaking.
    pycapsule: ctypes.py_object = ctypes.cast(handle, ctypes.py_object)
    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, c_str_dltensor):
        dl_managed_tensor = ctypes.pythonapi.PyCapsule_GetPointer(
            pycapsule, c_str_dltensor
        )
        managed_tensor_deleter(dl_managed_tensor)
        ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, None)
def is_contiguous_data(
    ndim: ctypes.c_int,
    shape: ctypes.POINTER(ctypes.c_int64),  # type: ignore
    stride: ctypes.POINTER(ctypes.c_int64),  # type: ignore
):
    """Return True when (shape, stride) describe compact row-major data.

    A NULL/absent stride array means compact by DLPack convention.
    """
    # Missing strides imply a compact, row-major layout.
    if stride is None or not bool(stride):
        return True
    # Walk axes from innermost to outermost, accumulating the stride a
    # compact layout would have, and compare against the actual strides.
    expected = 1
    for axis in reversed(range(ndim)):  # type: ignore
        if stride[axis] != expected:
            return False
        expected *= shape[axis]
    return True
def get_byte_size(
    dtype: DLDataType, ndim: ctypes.c_int, shape: ctypes.POINTER(ctypes.c_int64)  # type: ignore
):
    """Return the tensor's total size in bytes.

    Element size is bits * lanes / 8 (8-bit bytes assumed), multiplied
    by every shape extent.
    """
    nbytes = dtype.bits * dtype.lanes // 8  # assume 8 bits per byte
    for axis in range(ndim):  # type: ignore
        nbytes *= shape[axis]
    return nbytes
def get_dlpack_capsule(dlpack_obj, stream=None):
    """Return the PyCapsule holding the object's DLManagedTensor.

    Supports both the modern ``__dlpack__`` protocol and the legacy
    convention where the capsule itself is passed around.

    Parameters
    ----------
    dlpack_obj : an object exporting ``__dlpack__`` or a raw capsule
    stream : producer-side stream to synchronize with (CUDA only)
    """
    if hasattr(dlpack_obj, "__dlpack__"):
        if not hasattr(dlpack_obj, "__dlpack_device__"):
            _raise_error(
                "DLPack expects '__dlpack_device__' if '__dlpack__' has been defined"
            )
        # __dlpack_device__() returns a (device_type, device_id) tuple;
        # only the device type matters here. (Previously the whole tuple
        # was compared against the scalar kDLCUDA constant, which never
        # matched, so the stream was never forwarded to CUDA producers.)
        device_type, _ = dlpack_obj.__dlpack_device__()
        # Have to condition on the device type as, using numpy as example,
        # some DLPack implementations don't accept 'stream' as an argument.
        if device_type != DLDeviceType.kDLCUDA:
            return dlpack_obj.__dlpack__()
        else:
            return dlpack_obj.__dlpack__(stream)
    else:
        # Old interface where the PyCapsule object is passed directly.
        return dlpack_obj
def get_dlpack_device(dlpack_obj):
    """Return the object's (device_type, device_id) pair, or None when it
    does not implement ``__dlpack_device__``."""
    device_getter = getattr(dlpack_obj, "__dlpack_device__", None)
    if device_getter is None:
        return None
    return device_getter()
def get_managed_tensor(dlcapsule):
    """Return the DLManagedTensor struct stored inside a DLPack capsule."""
    ptr = ctypes.pythonapi.PyCapsule_GetPointer(dlcapsule, c_str_dltensor)
    return DLManagedTensor.from_address(ptr)
class DLPackObject:
    """Wrapper that captures a producer's DLPack capsule and exposes the
    tensor's metadata (shape, dtype, device, data pointer).

    The capsule is held for the wrapper's lifetime; all properties read
    straight from the embedded DLTensor struct.
    """

    def __init__(self, value) -> None:
        """Capture *value*'s DLPack capsule.

        Raises:
            ValueError: if *value* does not support the DLPack protocol
                or (for CUDA tensors) cupy is unavailable.
        """
        try:
            stream = None
            device, device_id = value.__dlpack_device__()
            if device == DLDeviceType.kDLCUDA:
                # GPU tensor: ask the producer to synchronize with the
                # legacy default stream before handing over the capsule.
                if cupy is None:
                    raise ValueError(
                        f"DLPack synchronization on device {device,device_id} not supported"
                    )
                with cupy.cuda.Device(device_id):
                    stream = 1  # legacy default stream
                self._capsule = get_dlpack_capsule(value, stream)
                self._tensor = get_managed_tensor(self._capsule).dl_tensor
            else:
                self._capsule = get_dlpack_capsule(value)
                self._tensor = get_managed_tensor(self._capsule).dl_tensor
        except Exception as e:
            raise ValueError(f"Object does not support DLPack protocol: {e}") from None

    def __eq__(self, other) -> bool:
        # Field-by-field comparison of the tensor metadata; note this
        # compares data *pointers*, not data contents.
        if not isinstance(other, DLPackObject):
            return False
        if self.byte_size != other.byte_size:
            return False
        if self.memory_type != other.memory_type:
            return False
        if self.memory_type_id != other.memory_type_id:
            return False
        if self.shape != other.shape:
            return False
        if self.data_ptr != other.data_ptr:
            return False
        if self.contiguous != other.contiguous:
            return False
        if self.data_type != other.data_type:
            return False
        return True

    @property
    def byte_size(self) -> int:
        """Total data size in bytes."""
        return get_byte_size(self._tensor.dtype, self._tensor.ndim, self._tensor.shape)

    @property
    def memory_type(self) -> MemoryType:
        """MemoryType for the tensor's device (CPU or GPU)."""
        return DLPACK_DEVICE_TYPE_TO_MEMORY_TYPE[self._tensor.device.device_type]

    @property
    def memory_type_id(self) -> int:
        """Device ordinal of the tensor's device."""
        return self._tensor.device.device_id

    @property
    def shape(self) -> list[int]:
        """Shape as a plain Python list of ints."""
        return [self._tensor.shape[i] for i in range(self._tensor.ndim)]

    @property
    def data_type(self) -> DataType:
        """ICP DataType corresponding to the DLPack dtype."""
        return DLPACK_TO_DATA_TYPE[self.dlpack_data_type]

    @property
    def dlpack_data_type(self) -> tuple[DLDataTypeCode, int]:
        """(type_code, bit width) pair straight from the DLTensor."""
        return (self._tensor.dtype.type_code, self._tensor.dtype.bits)

    @property
    def data_ptr(self) -> ctypes.c_void_p:
        # Raw address of the first element (base pointer plus byte_offset);
        # returned as a plain int despite the c_void_p annotation.
        return self._tensor.data + self._tensor.byte_offset

    @property
    def contiguous(self) -> bool:
        """True when the data is compact row-major."""
        return is_contiguous_data(
            self._tensor.ndim, self._tensor.shape, self._tensor.strides
        )
# DLPack device types representable as ICP memory types. CustomKeyErrorDict
# presumably labels lookup failures with these two names for clearer errors.
DLPACK_DEVICE_TYPE_TO_MEMORY_TYPE: dict[DLDeviceType, MemoryType] = CustomKeyErrorDict(
    "DLPack device type",
    "Memory type",
    {
        DLDeviceType.kDLCUDA: MemoryType.GPU,
        DLDeviceType.kDLCPU: MemoryType.CPU,
    },
)
# Inverse of the above, with CPU_PINNED additionally mapped to kDLCPU
# (pinned host memory is still host-addressable for DLPack purposes).
MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE: dict[MemoryType, DLDeviceType] = CustomKeyErrorDict(
    "Memory type",
    "DLPack device type",
    {
        **{value: key for key, value in DLPACK_DEVICE_TYPE_TO_MEMORY_TYPE.items()},
        **{MemoryType.CPU_PINNED: DLDeviceType.kDLCPU},
    },
)
def parse_device_or_memory_type(
    device_or_memory_type: DeviceOrMemoryType,
) -> tuple[MemoryType, int]:
    """Normalize any accepted device/memory spelling to (MemoryType, id).

    Accepts:
      * (MemoryType, id) tuples,
      * (DLDeviceType, id) tuples (mapped via DLPACK_DEVICE_TYPE_TO_MEMORY_TYPE),
      * a bare MemoryType (id defaults to 0),
      * strings "TYPE" or "TYPE:id" (type name matched case-insensitively).

    Raises:
        ValueError: for malformed strings or unsupported input types.
    """
    memory_type_id = 0
    if isinstance(device_or_memory_type, tuple):
        if isinstance(device_or_memory_type[0], MemoryType):
            memory_type = device_or_memory_type[0]
            memory_type_id = device_or_memory_type[1]
        elif isinstance(device_or_memory_type[0], DLDeviceType):
            memory_type = DLPACK_DEVICE_TYPE_TO_MEMORY_TYPE[device_or_memory_type[0]]
            memory_type_id = device_or_memory_type[1]
        else:
            raise ValueError(f"Invalid memory type {device_or_memory_type}")
    elif isinstance(device_or_memory_type, MemoryType):
        memory_type = device_or_memory_type
        memory_type_id = 0
    elif isinstance(device_or_memory_type, str):
        memory_str_tuple = device_or_memory_type.split(":")
        if len(memory_str_tuple) > 2:
            raise ValueError(f"Invalid memory type string {device_or_memory_type}")
        memory_type = string_to_memory_type(memory_str_tuple[0].upper())
        if len(memory_str_tuple) == 2:
            try:
                memory_type_id = int(memory_str_tuple[1])
            except ValueError:
                raise ValueError(
                    f"Invalid memory type string {device_or_memory_type}"
                ) from None
        else:
            memory_type_id = 0
    else:
        # Previously unsupported inputs fell through and returned
        # (None, 0), violating the declared return type; reject them
        # explicitly instead.
        raise ValueError(f"Invalid memory type {device_or_memory_type}")
    return (memory_type, memory_type_id)
# Mapping from (DLPack type code, bit width) to the ICP DataType enum.
# One entry per line; complex types and non-standard widths are
# intentionally unsupported and surface as lookup errors.
DLPACK_TO_DATA_TYPE: dict[tuple[DLDataTypeCode, int], DataType] = CustomKeyErrorDict(
    "DLPack data type",
    "Data type",
    {
        (DLDataTypeCode.kDLBool, 8): DataType.BOOL,
        (DLDataTypeCode.kDLInt, 8): DataType.INT8,
        (DLDataTypeCode.kDLInt, 16): DataType.INT16,
        (DLDataTypeCode.kDLInt, 32): DataType.INT32,
        (DLDataTypeCode.kDLInt, 64): DataType.INT64,
        (DLDataTypeCode.kDLUInt, 8): DataType.UINT8,
        (DLDataTypeCode.kDLUInt, 16): DataType.UINT16,
        (DLDataTypeCode.kDLUInt, 32): DataType.UINT32,
        (DLDataTypeCode.kDLUInt, 64): DataType.UINT64,
        (DLDataTypeCode.kDLFloat, 16): DataType.FP16,
        (DLDataTypeCode.kDLFloat, 32): DataType.FP32,
        (DLDataTypeCode.kDLFloat, 64): DataType.FP64,
        (DLDataTypeCode.kDLBfloat, 16): DataType.BF16,
    },
)
# Inverse mapping: each supported DataType to a ready-made DLDataType
# struct, with lanes fixed at 1 (vector lanes are never produced here).
DATA_TYPE_TO_DLPACK_DTYPE: dict[DataType, DLDataType] = CustomKeyErrorDict(
    "Data type",
    "DLPack data type",
    {
        value: DLDataType(type_code=key[0], bits=key[1], lanes=1)
        for key, value in DLPACK_TO_DATA_TYPE.items()
    },
)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract Class for interacting with Triton Distributed Inter-Component Protocol Data Plane"""
import abc
import uuid
from typing import Optional, Sequence
import cupy
import numpy
from triton_distributed.icp.data_type import (
DATA_TYPE_TO_NUMPY_DTYPE,
DataType,
string_to_data_type,
)
from triton_distributed.icp.memory_buffer import MemoryBuffer
from triton_distributed.icp.memory_type import MemoryType, string_to_memory_type
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.tensor import Tensor
class DataPlaneError(Exception):
    """Base exception for data-plane tensor transport failures."""

    pass
# Keys used in the tensor messages' `parameters` map to carry side-band
# transport metadata (location, memory type/id, and byte size).
ICP_TENSOR_URI = "icp_tensor_uri"
ICP_MEMORY_TYPE = "icp_memory_type"
ICP_MEMORY_TYPE_ID = "icp_memory_type_id"
ICP_TENSOR_SIZE = "icp_tensor_size"
def set_icp_shape(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
    value: Sequence[int],
) -> None:
    """Record the tensor shape on the message's repeated `shape` field."""
    message.shape.extend(value)
def get_icp_shape(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> Sequence[int]:
    """Return the tensor shape stored on the message (live repeated field)."""
    return message.shape
def set_icp_data_type(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
    value: DataType,
) -> None:
    """Store the DataType's name (e.g. "FP32") in the message's `datatype`."""
    message.datatype = value.name
def get_icp_data_type(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> DataType:
    """Parse the message's `datatype` string back into a DataType.

    Raises ValueError (via string_to_data_type) for unknown names.
    """
    return string_to_data_type(message.datatype)
def set_icp_tensor_uri(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
    value: str,
) -> None:
    """Attach the out-of-band tensor location URI to the message."""
    message.parameters[ICP_TENSOR_URI].string_param = value
def get_icp_tensor_uri(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> str | None:
    """Return the out-of-band tensor URI, or None when not present."""
    params = message.parameters
    # Membership test kept deliberately (LBYL): indexing a protobuf map
    # can insert a default entry, mutating the message.
    if ICP_TENSOR_URI not in params:
        return None
    return params[ICP_TENSOR_URI].string_param
def set_icp_tensor_size(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
    value: int,
) -> None:
    """Record the tensor's byte size in the message's parameters map."""
    message.parameters[ICP_TENSOR_SIZE].uint64_param = value
def get_icp_tensor_size(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> int | None:
    """Return the tensor's byte size parameter, or None when not present."""
    params = message.parameters
    # LBYL on purpose: indexing a protobuf map can create the entry.
    if ICP_TENSOR_SIZE not in params:
        return None
    return params[ICP_TENSOR_SIZE].uint64_param
def set_icp_memory_type(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
    value: MemoryType,
) -> None:
    """Record the memory type by enum name (e.g. "CPU", "GPU")."""
    message.parameters[ICP_MEMORY_TYPE].string_param = value.name
def get_icp_memory_type(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> MemoryType | None:
    """Return the stored MemoryType, or None when not present.

    Raises ValueError (via string_to_memory_type) for unknown names.
    """
    params = message.parameters
    # LBYL on purpose: indexing a protobuf map can create the entry.
    if ICP_MEMORY_TYPE not in params:
        return None
    return string_to_memory_type(params[ICP_MEMORY_TYPE].string_param)
def set_icp_memory_type_id(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
    value: int,
) -> None:
    """Record the memory type id (typically the device ordinal)."""
    message.parameters[ICP_MEMORY_TYPE_ID].int64_param = value
def get_icp_memory_type_id(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> int | None:
    """Return the stored memory type id, or None when not present."""
    params = message.parameters
    # LBYL on purpose: indexing a protobuf map can create the entry.
    if ICP_MEMORY_TYPE_ID not in params:
        return None
    return params[ICP_MEMORY_TYPE_ID].int64_param
def set_icp_tensor_contents(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
    tensor: Tensor,
) -> None:
    """Serialize *tensor* inline into the message's `contents` field.

    Contents always travel as CPU bytes, so the message advertises
    MemoryType.CPU regardless of where the tensor currently lives.
    BYTES tensors are appended element by element; every other dtype is
    appended as one raw-bytes blob.
    """
    set_icp_memory_type(message, MemoryType.CPU)
    set_icp_memory_type_id(message, 0)
    set_icp_tensor_size(message, tensor.size)
    if tensor.data_type == DataType.BYTES:
        # Variable-length elements: one bytes_contents entry per element.
        array = tensor.to_bytes_array()
        for i in list(array.flat):
            message.contents.bytes_contents.append(i)
    else:
        if tensor.memory_type == MemoryType.CPU:
            # Directly use the memory buffer when contents on the CPU.
            array = tensor.memory_buffer.owner
        elif tensor.memory_type == MemoryType.GPU:
            # Bring device memory host-side via cupy before serializing.
            with cupy.cuda.Device(tensor.memory_buffer.memory_type_id):
                array = cupy.from_dlpack(tensor)
        else:
            raise ValueError(f"Invalid Tensor Memory Type {tensor.memory_type}")
        # Fixed-width dtypes travel as a single flat blob.
        message.contents.bytes_contents.append(array.tobytes())
def get_icp_tensor_contents(
    message: ModelInferRequest.InferInputTensor | ModelInferResponse.InferOutputTensor,
) -> Tensor | None:
    """Rebuild a Tensor from inline message contents.

    Returns None when the message carries no inline contents (i.e. the
    tensor travels out-of-band instead). Inverse of
    set_icp_tensor_contents.
    """
    if not message.HasField("contents"):
        # Return None if the content is not part of message
        return None
    datatype = get_icp_data_type(message)
    shape = get_icp_shape(message)
    tensor = None
    if datatype == DataType.BYTES:
        # One bytes_contents entry per element; restore the shape here.
        array = numpy.array(
            [
                message.contents.bytes_contents[i]
                for i in range(len(message.contents.bytes_contents))
            ]
        )
        array = numpy.reshape(array, shape)
        tensor = Tensor.from_bytes_array(array)
    else:
        # Single flat blob; numpy.array(...) copies so the buffer is
        # writable and independent of the protobuf message.
        array = numpy.array(
            numpy.frombuffer(
                message.contents.bytes_contents[0],
                dtype=DATA_TYPE_TO_NUMPY_DTYPE[datatype],
            )
        )
        # The flat buffer is paired with the message's shape; Tensor is
        # presumably responsible for interpreting it — TODO confirm.
        tensor = Tensor(datatype, shape, MemoryBuffer.from_dlpack(array))
    return tensor
class DataPlane(abc.ABC):
    """Abstract interface for moving tensor data between components.

    Implementations register local tensors, hand out message-embedded
    references to them, and resolve such references back into Tensors.
    """

    def __init__(self) -> None:
        pass

    @abc.abstractmethod
    def connect(self) -> None:
        """Establish the transport connection."""
        pass

    @abc.abstractmethod
    def put_input_tensor(
        self, tensor: Tensor, tensor_id: Optional[uuid.UUID], use_tensor_contents: bool
    ) -> ModelInferRequest.InferInputTensor:
        """Publish *tensor* for a request; returns the wire-level message.

        `use_tensor_contents` selects inline serialization over
        out-of-band transfer.
        """
        pass

    @abc.abstractmethod
    def put_output_tensor(
        self, tensor: Tensor, tensor_id: Optional[uuid.UUID], use_tensor_contents: bool
    ) -> ModelInferResponse.InferOutputTensor:
        """Publish *tensor* for a response; returns the wire-level message."""
        pass

    @abc.abstractmethod
    def get_tensor(
        self,
        remote_tensor: ModelInferRequest.InferInputTensor
        | ModelInferResponse.InferOutputTensor,
        requested_memory_type: Optional[MemoryType] = None,
        requested_memory_type_id: Optional[int] = None,
    ) -> Tensor:
        """Resolve a wire-level tensor message into a local Tensor,
        optionally placed in the requested memory."""
        pass

    @abc.abstractmethod
    def create_input_tensor_reference(
        self,
        remote_tensor: ModelInferRequest.InferInputTensor
        | ModelInferResponse.InferOutputTensor,
    ) -> ModelInferRequest.InferInputTensor:
        """Re-wrap an existing remote tensor as a request input reference."""
        pass

    @abc.abstractmethod
    def create_output_tensor_reference(
        self,
        remote_tensor: ModelInferRequest.InferInputTensor
        | ModelInferResponse.InferOutputTensor,
    ) -> ModelInferResponse.InferOutputTensor:
        """Re-wrap an existing remote tensor as a response output reference."""
        pass

    @abc.abstractmethod
    def release_tensor(
        self,
        remote_tensor: ModelInferRequest.InferInputTensor
        | ModelInferResponse.InferOutputTensor,
    ) -> None:
        """Release the resources backing a previously published tensor."""
        pass
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from enum import IntEnum
import numpy
from triton_distributed.icp._custom_key_error_dict import CustomKeyErrorDict
# Tensor element types used on the wire; names follow the Triton-style
# datatype strings (FP32, BYTES, ...). INVALID is 0 so a default/zero
# value reads as invalid; member order is part of the wire contract.
DataType = IntEnum(
    "DataType",
    names=(
        "INVALID",
        "BOOL",
        "UINT8",
        "UINT16",
        "UINT32",
        "UINT64",
        "INT8",
        "INT16",
        "INT32",
        "INT64",
        "FP16",
        "FP32",
        "FP64",
        "BYTES",
        "BF16",
    ),
    start=0,
)
def string_to_data_type(data_type_string: str) -> DataType:
    """Convert a datatype name (e.g. "FP32") to its DataType member.

    Raises:
        ValueError: for names that are not DataType members.
    """
    member = DataType.__members__.get(data_type_string)
    if member is None:
        raise ValueError(
            f"Unsupported Data Type String. Can't convert {data_type_string} to DataType"
        ) from None
    return member
# Numpy scalar type -> wire DataType. Note the many-to-one entries:
# bool/numpy.bool_ both map to BOOL, and bytes_/str_/object_ all map to
# BYTES (variable-length payloads).
NUMPY_TO_DATA_TYPE: dict[type, DataType] = CustomKeyErrorDict(
    "Numpy dtype",
    "Data type",
    {
        bool: DataType.BOOL,
        numpy.bool_: DataType.BOOL,
        numpy.int8: DataType.INT8,
        numpy.int16: DataType.INT16,
        numpy.int32: DataType.INT32,
        numpy.int64: DataType.INT64,
        numpy.uint8: DataType.UINT8,
        numpy.uint16: DataType.UINT16,
        numpy.uint32: DataType.UINT32,
        numpy.uint64: DataType.UINT64,
        numpy.float16: DataType.FP16,
        numpy.float32: DataType.FP32,
        numpy.float64: DataType.FP64,
        numpy.bytes_: DataType.BYTES,
        numpy.str_: DataType.BYTES,
        numpy.object_: DataType.BYTES,
    },
)
# Inverse mapping; the plain dict inversion would keep whichever numpy
# key came last, so BYTES and BOOL are pinned explicitly afterwards
# (BYTES -> object_, BOOL -> numpy.bool_ rather than builtin bool).
DATA_TYPE_TO_NUMPY_DTYPE: dict[DataType, type] = CustomKeyErrorDict(
    "Data type",
    "Numpy dtype",
    {
        **{value: key for key, value in NUMPY_TO_DATA_TYPE.items()},
        **{DataType.BYTES: numpy.object_},
        **{DataType.BOOL: numpy.bool_},
    },
)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import re
import uuid
from abc import abstractmethod
from datetime import datetime
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Type, Union
EVENT_TOPIC_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
def _validate_topics(topics: List[str]) -> bool:
"""
Checks if all strings in the list are alphanumeric and can contain underscores (_) and hyphens (-).
:param subjects: List of strings to validate
:return: True if all strings are valid, False otherwise
"""
pattern = EVENT_TOPIC_PATTERN
return all(pattern.match(topic) for topic in topics)
@dataclasses.dataclass
class EventTopic:
    """Hierarchical identifier for an event stream.

    Stored canonically as a dot-joined string (e.g. "a.b.c"); each path
    component must be alphanumeric plus underscores/hyphens.
    """

    event_topic: str

    def __init__(self, event_topic: Union[List[str], str]):
        """Build a topic from a list of components or a dotted string.

        Args:
            event_topic: list of components, or a single (optionally
                dotted) string. Components may contain only
                alphanumerics, underscores, and hyphens.

        Raises:
            ValueError: when any component contains other characters.
        """
        # Normalize to a component list; a dotted string is treated as
        # an already-joined hierarchy (splitting "x" just yields ["x"]).
        if isinstance(event_topic, str):
            components = event_topic.split(".")
        else:
            components = list(event_topic)
        if not _validate_topics(components):
            raise ValueError(
                "Invalid event_topic. Only alphanumeric characters, underscores, and hyphens are allowed."
            )
        self.event_topic = ".".join(components)

    def __str__(self):
        return self.event_topic
class Event:
    """Abstract view of a single event delivered over the event plane."""

    @property
    @abstractmethod
    def event_id(self) -> uuid.UUID:
        """Unique identifier of this event."""
        pass

    @property
    @abstractmethod
    def event_type(self) -> str:
        """Application-defined type tag used for subscription filtering."""
        pass

    @property
    @abstractmethod
    def timestamp(self) -> datetime:
        """When the event was published."""
        pass

    @property
    @abstractmethod
    def component_id(self) -> uuid.UUID:
        """Identity of the component that published the event."""
        pass

    @property
    @abstractmethod
    def event_topic(self) -> Optional[EventTopic]:
        """Topic the event was published under, if any."""
        pass

    @property
    @abstractmethod
    def payload(self) -> bytes:
        """Raw event payload bytes."""
        pass

    @abstractmethod
    def typed_payload(self, payload_type: Optional[Type | str] = None) -> Any:
        """Deserialize the payload, optionally into an explicit type."""
        pass
class EventSubscription(AsyncIterator[Event]):
    """Handle for an active subscription; async-iterate it to consume events."""

    @abstractmethod
    async def __anext__(self) -> Event:
        """Return the next received event."""
        pass

    @abstractmethod
    def __aiter__(self):
        return self

    @abstractmethod
    def unsubscribe(self):
        """Stop receiving events on this subscription."""
        pass
class EventPlane:
    """EventPlane interface for publishing and subscribing to events."""

    @abstractmethod
    async def connect(self):
        """Connect to the event plane."""
        pass

    @abstractmethod
    async def publish(
        self,
        event: Union[bytes, Any],
        event_type: str,
        event_topic: Optional[EventTopic],
    ) -> Event:
        """Publish an event to the event plane.

        Args:
            event (Union[bytes, Any]): Event payload
            event_type (str): Event type
            event_topic (Optional[EventTopic]): Event event_topic

        Returns:
            The published event, including its assigned metadata.
        """
        pass

    @abstractmethod
    async def subscribe(
        self,
        callback: Callable[[Event], Awaitable[None]],
        event_topic: Optional[EventTopic] = None,
        event_type: Optional[str] = None,
        component_id: Optional[uuid.UUID] = None,
    ) -> EventSubscription:
        """Subscribe to events on the event plane.

        Args:
            callback (Callable[[Event], Awaitable[None]]): Callback function to be called when an event is received
            event_topic (Optional[EventTopic]): Event event_topic filter
            event_type (Optional[str]): Event type filter
            component_id (Optional[uuid.UUID]): Component ID filter

        Returns:
            The created subscription handle.
        """
        pass

    @abstractmethod
    async def disconnect(self):
        """Disconnect from the event plane."""
        pass
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from triton_distributed.icp._dlpack import DLPackObject
from triton_distributed.icp.memory_type import MemoryType
@dataclass
class MemoryBuffer:
"""Memory allocated for a Tensor.
This object does not own the memory but holds a reference to the
owner.
Parameters
----------
data_ptr : int
Pointer to the allocated memory.
memory_type : MemoryType
memory type
memory_type_id : int
memory type id (typically the same as device id)
size : int
Size of the allocated memory in bytes.
owner : Any
Object that owns or manages the memory buffer. Allocated
memory must not be freed while a reference to the owner is
held.
Examples
--------
>>> buffer = MemoryBuffer.from_dlpack(numpy.array([100],dtype=numpy.uint8))
"""
data_ptr: int
memory_type: MemoryType
memory_type_id: int
size: int
owner: Any
@staticmethod
def from_dlpack(owner: Any) -> MemoryBuffer:
if not hasattr(owner, "__dlpack__"):
raise ValueError("Object does not support DLpack protocol")
dlpack_object = DLPackObject(owner)
return MemoryBuffer._from_dlpack_object(owner, dlpack_object)
@staticmethod
def _from_dlpack_object(owner: Any, dlpack_object: DLPackObject) -> MemoryBuffer:
if not dlpack_object.contiguous:
raise ValueError("Only contiguous memory is supported")
return MemoryBuffer(
int(dlpack_object.data_ptr),
dlpack_object.memory_type,
dlpack_object.memory_type_id,
dlpack_object.byte_size,
owner,
)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntEnum
MemoryType = IntEnum("MemoryType", names=("CPU", "CPU_PINNED", "GPU"), start=0)
def string_to_memory_type(memory_type_string: str) -> MemoryType:
try:
return MemoryType[memory_type_string]
except KeyError:
raise ValueError(
f"Unsupported Memory Type String. Can't convert {memory_type_string} to MemoryType"
) from None
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import datetime
import logging
import os
import uuid
from typing import Any, Awaitable, Callable, List, Optional, Union
import msgspec
import nats
from triton_distributed.icp import EventTopic
from triton_distributed.icp.event_plane import Event, EventSubscription
from triton_distributed.icp.on_demand_event import (
EventMetadata,
OnDemandEvent,
_serialize_metadata,
)
logger = logging.getLogger(__name__)

# Connection defaults, each overridable through environment variables.
DEFAULT_EVENTS_PORT = int(os.getenv("DEFAULT_EVENTS_PORT", 4222))
DEFAULT_EVENTS_HOST = os.getenv("DEFAULT_EVENTS_HOST", "localhost")
DEFAULT_EVENTS_PROTOCOL = os.getenv("DEFAULT_EVENTS_PROTOCOL", "nats")
DEFAULT_CONNECTION_TIMEOUT = int(os.getenv("DEFAULT_CONNECTION_TIMEOUT", 30))
# Subject prefix that namespaces all event-plane traffic on NATS.
EVENT_PLANE_NATS_PREFIX = "event_plane_nats_v1"


def compose_nats_uri(
    protocol: str = DEFAULT_EVENTS_PROTOCOL,
    host: str = DEFAULT_EVENTS_HOST,
    port: int = DEFAULT_EVENTS_PORT,
) -> str:
    """Build a NATS server URI of the form ``protocol://host:port``.

    Args:
        protocol: The protocol to use (tls or nats). Defaults to DEFAULT_EVENTS_PROTOCOL.
        host: The host to connect to. Defaults to DEFAULT_EVENTS_HOST.
        port: The port to connect to. Defaults to DEFAULT_EVENTS_PORT.

    Returns:
        str: The composed NATS URI.
    """
    return f"{protocol}://{host}:{port}"
class NatsEventSubscription(EventSubscription):
    """Async iterator over events arriving on one NATS subscription.

    Iteration yields OnDemandEvent objects; it raises once the
    connection fails, and stops after `unsubscribe`.
    """

    def __init__(
        self,
        nc_sub: nats.aio.subscription.Subscription,
        nats_connection: Any,
        subject: str,
        topic: EventTopic,
    ):
        self._nc_sub: Optional[nats.aio.subscription.Subscription] = nc_sub
        self._nats = nats_connection
        self._subject = subject
        self._topic = topic
        # Last observed connection error. Previously this attribute was
        # never initialized, so __anext__ raised AttributeError instead
        # of a meaningful RuntimeError after a connection loss.
        self._error: Optional[Exception] = None
        self._unsubscribe_event: asyncio.Event = asyncio.Event()

    async def __anext__(self):
        if self._nc_sub is None:
            raise StopAsyncIteration
        # NOTE: is_connected is a method on the event plane and must be
        # *called* — the previous truthiness test on the bound method was
        # always True, so a lost connection was never detected here.
        if not self._nats.is_connected():
            if self._error is not None:
                raise RuntimeError(
                    f"NATS connection error: {self._error}"
                ) from self._error
            else:
                raise RuntimeError("NATS connection failure.")
        else:
            # Race the next message against a connection failure so the
            # iterator cannot hang forever on a dead connection.
            failure_task = asyncio.create_task(self._nats.wait_for_failure())
            next_task = asyncio.create_task(self._nc_sub.next_msg())
            _ = await asyncio.wait(
                [next_task, failure_task], return_when=asyncio.FIRST_COMPLETED
            )
            if failure_task.done():
                logger.warning("NATS connection failure.")
                try:
                    next_task.cancel()
                    await next_task
                except asyncio.CancelledError:
                    pass
                raise RuntimeError("NATS connection failure.") from failure_task.exception()
            else:
                try:
                    failure_task.cancel()
                    await failure_task
                except asyncio.CancelledError:
                    pass
                msg = next_task.result()
                metadata, event_payload = NatsEventPlane._extract_metadata_and_payload(
                    msg.data
                )
                event = OnDemandEvent(event_payload, metadata)
                return event

    def __aiter__(self):
        return self

    async def unsubscribe(self):
        """Drop the underlying NATS subscription (safe to call twice)."""
        if self._nc_sub is None:
            return
        if self._nats.is_connected():
            await self._nc_sub.unsubscribe()
            self._nc_sub = None
        else:
            logger.warning("NATS not connected. Cannot unsubscribe.")

    @property
    def subject(self):
        """Concrete NATS subject this subscription listens on."""
        return self._subject

    @property
    def topic(self):
        """EventTopic filter associated with this subscription (may be None)."""
        return self._topic

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.unsubscribe()
        return False  # Don't suppress exceptions
class NatsEventPlane:
"""EventPlane implementation using NATS."""
def __init__(
self,
server_uri: str = compose_nats_uri(),
component_id: Optional[uuid.UUID] = uuid.uuid4(),
run_callback_in_parallel: bool = False,
):
"""Initialize the NATS event plane.
Args:
server_uri: URI of the NATS server. If None, will be composed using environment variables.
component_id: Component ID.
"""
self._run_callback_in_parallel = run_callback_in_parallel
if server_uri is None:
server_uri = compose_nats_uri()
self._server_uri = server_uri
if component_id is None:
component_id = uuid.uuid4()
self._component_id = component_id
self._nc = nats.NATS()
self._error: Optional[Exception] = None
self._connected = False
self._failure_event: Optional[asyncio.Event] = None
async def wait_for_failure(self):
"""Wait for a failure event."""
if self._failure_event is not None:
await self._failure_event.wait()
raise RuntimeError("NATS connection failure.") from self._error
else:
raise RuntimeError("NATS connection failure event is None")
    def is_connected(self):
        """Return the cached connection-health flag (maintained by the
        NATS lifecycle callbacks registered in connect())."""
        return self._connected
    async def connect(self):
        """Connect to the NATS server.

        Registers lifecycle callbacks, then races the connect attempt
        against the failure event so a reported error aborts the
        connect instead of waiting out the timeout.

        Raises:
            RuntimeError: on connection failure or timeout.
        """
        if self._connected:
            return
        async def error_cb(e):
            # Record the error and signal waiters, then re-arm the event
            # so subsequent failures can be signalled again.
            logger.warning("NATS error: %s", e)
            self._error = e
            if self._failure_event is not None:
                self._failure_event.set()
                self._failure_event = asyncio.Event()
            else:
                logger.error(f"NATS connection failure event is None for error {e}")
        async def reconnected_cb():
            logger.debug("NATS reconnected")
            self._connected = True
        async def disconnected_cb():
            logger.debug("NATS disconnected")
            self._connected = False
        async def closed_cb():
            logger.debug("NATS closed")
            self._connected = False
        self._failure_event = asyncio.Event()
        try:
            # Bound the whole handshake by the configured timeout.
            async with asyncio.timeout(DEFAULT_CONNECTION_TIMEOUT):
                logger.debug(f"Connecting to NATS server: {self._server_uri}")
                connect_task = asyncio.create_task(
                    self._nc.connect(
                        self._server_uri,
                        error_cb=error_cb,
                        reconnected_cb=reconnected_cb,
                        disconnected_cb=disconnected_cb,
                        closed_cb=closed_cb,
                    )
                )
                failed_task = asyncio.create_task(self.wait_for_failure())
                # First of {connect completes, failure reported} wins.
                await asyncio.wait(
                    [connect_task, failed_task], return_when=asyncio.FIRST_COMPLETED
                )
                if failed_task.done():
                    # Failure won: cancel the connect attempt and re-raise
                    # with the recorded cause.
                    try:
                        connect_task.cancel()
                        await connect_task
                    except asyncio.CancelledError:
                        pass
                    raise RuntimeError(
                        "NATS connection failure."
                    ) from failed_task.exception()
                else:
                    # Connect won: retire the failure watcher.
                    try:
                        failed_task.cancel()
                        await failed_task
                    except asyncio.CancelledError:
                        pass
        except asyncio.TimeoutError:
            raise RuntimeError(
                f"NATS connection timeout {DEFAULT_CONNECTION_TIMEOUT} reached."
            )
        logger.debug(f"Connected to NATS server: {self._server_uri}")
        self._connected = True
async def publish(
self,
payload: bytes | Any,
event_type: Optional[str] = None,
event_topic: Optional[EventTopic | str | List[str]] = None,
timestamp: Optional[datetime.datetime] = datetime.datetime.now(datetime.UTC),
event_id: Optional[uuid.UUID] = uuid.uuid4(),
) -> Event:
"""Publish an event to the NATS server.
Args:
payload: Event payload.
event_type: Type of the event.
event_topic: EventTopic of the event.
"""
if not self._connected:
if self._error:
raise RuntimeError(
f"NATS connection error: {self._error}"
) from self._error
else:
raise RuntimeError("NATS not connected.")
if timestamp is None:
timestamp = datetime.datetime.now(datetime.UTC)
if event_id is None:
event_id = uuid.uuid4()
if event_topic is not None and not isinstance(event_topic, EventTopic):
event_topic = EventTopic(event_topic)
event_metadata = EventMetadata(
event_id=event_id,
event_topic=event_topic,
event_type=event_type if event_type else str(type(payload).__name__),
timestamp=timestamp,
component_id=self._component_id,
)
metadata_serialized = _serialize_metadata(event_metadata)
metadata_size = len(metadata_serialized).to_bytes(4, byteorder="big")
# Concatenate metadata size, metadata, and event payload
if isinstance(payload, bytes):
message = metadata_size + metadata_serialized + payload
else:
message = metadata_size + metadata_serialized + msgspec.json.encode(payload)
subject = self._compose_publish_subject(event_metadata)
await self._nc.publish(subject, message)
event_with_metadata = OnDemandEvent(
payload, metadata_serialized, event_metadata
)
return event_with_metadata
    async def subscribe(
        self,
        callback: Optional[Callable[[Event], Awaitable[None]]] = None,
        event_topic: Optional[EventTopic | str | List[str]] = None,
        event_type: Optional[str] = "*",
        component_id: Optional[uuid.UUID] = None,
    ) -> EventSubscription:
        """Subscribe to events on the NATS server.

        Args:
            callback: Coroutine invoked per received event. When None,
                the caller consumes events by async-iterating the
                returned subscription instead.
            event_topic: Event event_topic filter (None matches all).
            event_type: Event type filter ('*' matches all).
            component_id: Publishing component filter (None matches all).

        Returns:
            EventSubscription wrapping the underlying NATS subscription.

        Raises:
            RuntimeError: if not connected.
        """
        if not self._connected:
            if self._error:
                raise RuntimeError(
                    f"NATS connection error: {self._error}"
                ) from self._error
            else:
                raise RuntimeError("NATS not connected.")
        async def _message_handler(msg):
            # Decode the wire frame back into an event, then dispatch to
            # the user callback — either inline or as a parallel task.
            metadata, event_payload = NatsEventPlane._extract_metadata_and_payload(
                msg.data
            )
            event = OnDemandEvent(event_payload, metadata)
            async def wrapper():
                if callback is not None:
                    await callback(event)  # Ensure it's a proper coroutine
            if self._run_callback_in_parallel:
                if callback is not None:
                    asyncio.create_task(wrapper())  # Run in parallel
            else:
                if callback is not None:
                    await callback(event)  # Await normally
        subject_str, topic = self._compose_subscribe_subject(
            event_topic, event_type, component_id
        )
        # No handler when there is no callback — NOTE(review): cb=None
        # presumably enables pull-style next_msg() iteration; confirm
        # against the nats-py subscription semantics.
        _cb = _message_handler if callback is not None else None
        sub = await self._nc.subscribe(subject_str, cb=_cb)
        event_sub = NatsEventSubscription(sub, self, subject_str, topic)
        return event_sub
async def disconnect(self):
"""Disconnect from the NATS server."""
if not self._connected:
return
await self._nc.close()
self._error = RuntimeError("NATS connection closed by disconnect.")
if self._failure_event is not None:
self._failure_event.set()
self._connected = False
def _compose_publish_subject(self, event_metadata: EventMetadata):
return f"{EVENT_PLANE_NATS_PREFIX}.{event_metadata.event_type}.{event_metadata.component_id}.{str(event_metadata.event_topic) + '.' if event_metadata.event_topic else ''}trunk"
def _compose_subscribe_subject(
self,
event_topic: Optional[Union[EventTopic, str, List[str]]] = None,
event_type: Optional[str] = None,
component_id: Optional[uuid.UUID] = None,
):
if isinstance(event_topic, str) or isinstance(event_topic, list):
event_topic_obj = EventTopic(event_topic)
else:
event_topic_obj = event_topic
return (
f"{EVENT_PLANE_NATS_PREFIX}.{event_type or '*'}.{component_id or '*'}.{str(event_topic_obj) + '.' if event_topic else ''}>",
event_topic_obj,
)
@staticmethod
def _extract_metadata_and_payload(message: bytes):
# Extract metadata size
message_view = memoryview(message)
metadata_size = int.from_bytes(message_view[:4], byteorder="big")
# Extract metadata and event
metadata_serialized = message_view[4 : 4 + metadata_size]
event = message_view[4 + metadata_size :]
return metadata_serialized, event
@property
def component_id(self) -> uuid.UUID:
return self._component_id
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disconnect()
return False # Don't suppress exceptions
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import os
import shutil
import subprocess
import uuid
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Dict, Optional
from urllib.parse import urlsplit, urlunsplit
import nats
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest, ModelInferResponse
from triton_distributed.icp.request_plane import (
RequestPlane,
get_icp_final_response,
get_icp_request_id,
get_icp_response_error,
get_icp_response_to_uri,
set_icp_component_id,
set_icp_request_id,
set_icp_request_to_uri,
set_icp_response_to_uri,
)
class AsyncModelInferRequestIterator:
    """Adapts a list of ModelInferRequest messages to the async-iterator
    protocol, yielding them front-to-back."""

    def __init__(self, requests: list[ModelInferRequest]) -> None:
        self._requests = requests

    def __aiter__(self) -> AsyncModelInferRequestIterator:
        return self

    async def __anext__(self):
        # Consume from the front of the list; exhaustion ends the loop.
        if self._requests:
            return self._requests.pop(0)
        raise StopAsyncIteration
class AsyncModelInferResponseIterator:
    """Async iterator that drains responses from a queue until the final
    response (per ICP metadata) has been yielded."""

    def __init__(
        self,
        queue: Optional[asyncio.Queue],
        raise_on_error=False,
    ) -> None:
        self._queue = queue
        self._raise_on_error = raise_on_error
        # No queue means there is nothing to iterate at all.
        self._complete = not self._queue

    def __aiter__(self) -> AsyncModelInferResponseIterator:
        return self

    async def __anext__(self):
        if self._complete or self._queue is None:
            raise StopAsyncIteration
        response = await self._queue.get()
        # The final-response flag closes the iteration after this item.
        self._complete = get_icp_final_response(response)
        error = get_icp_response_error(response)
        if self._raise_on_error and error is not None:
            raise error
        return response

    def cancel(self) -> None:
        # Cancellation is not supported for queue-backed iteration.
        raise NotImplementedError()
class NatsServer:
    """Spawns and supervises a local ``nats-server`` child process with
    JetStream enabled; the process is torn down when this object dies.

    With ``dry_run=True`` the command line is printed and no process is
    launched.
    """

    def __init__(
        self,
        port: int = 4223,
        store_dir: str = "/tmp/nats_store",
        log_dir: str = "logs",
        debug: bool = False,
        clear_store: bool = True,
        dry_run: bool = False,
    ) -> None:
        self._process = None
        self.port = port
        self.url = f"nats://localhost:{port}"
        command = [
            "/usr/local/bin/nats-server",
            "--jetstream",
            "--port",
            str(port),
            "--store_dir",
            store_dir,
        ]
        if debug:
            command += ["--debug", "--trace"]
        if dry_run:
            # Show what would run, then stop without side effects.
            print(command)
            return
        if clear_store:
            # Wipe JetStream state left over from a previous run.
            shutil.rmtree(store_dir, ignore_errors=True)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
            # The child inherits the file descriptors, so the files can be
            # closed here once the process has been started.
            with open(f"{log_dir}/nats_server.stdout.log", "wt") as stdout_log, open(
                f"{log_dir}/nats_server.stderr.log", "wt"
            ) as stderr_log:
                self._process = subprocess.Popen(
                    command,
                    stdin=subprocess.DEVNULL,
                    stdout=stdout_log,
                    stderr=stderr_log,
                )
        else:
            self._process = subprocess.Popen(
                command,
                stdin=subprocess.DEVNULL,
            )

    def __del__(self):
        # Make sure the server does not outlive this wrapper.
        if self._process:
            self._process.terminate()
            self._process.kill()
            self._process.wait()
class NatsRequestPlane(RequestPlane):
    """RequestPlane implementation backed by NATS JetStream work queues.

    Each model gets its own work-queue stream for requests; this component
    additionally owns a dedicated response stream that remote workers
    publish replies to. Call :meth:`connect` before any other operation.
    """

    @property
    def component_id(self):
        # Stable identifier for this participant on the request plane.
        return self._component_id

    @property
    def response_uri(self):
        # URI remote workers publish responses to (request-plane URI with
        # this component's response stream name as the path).
        return self._response_uri

    async def close(self):
        """Close the underlying NATS client connection.

        NOTE(review): ``_nats_client`` is only assigned in connect(), so
        calling close() on a never-connected instance raises
        AttributeError — confirm whether that is intended.
        """
        if self._nats_client:
            await self._nats_client.close()

    def __del__(self):
        # Best-effort cleanup using the loop captured in connect().
        # NOTE(review): run_until_complete() raises if that loop is still
        # running; this finalizer is only safe after the loop has stopped.
        if self._event_loop and not self._event_loop.is_closed():
            self._event_loop.run_until_complete(self.close())

    def __init__(
        self,
        request_plane_uri: str = "nats://localhost:4222",
        component_id: Optional[uuid.UUID] = None,
    ) -> None:
        """Create a request-plane client (no I/O; call connect() to use it).

        Args:
            request_plane_uri: URI of the NATS server.
            component_id: Identifier for this component; a uuid1 is
                generated when not provided.
        """
        self._request_plane_uri = request_plane_uri
        self._component_id = component_id if component_id else uuid.uuid1()
        # The response stream is unique to this component instance.
        self._response_stream_name = f"component-{self._component_id}-response"
        # Advertised response URI: the request-plane URI with the response
        # stream name substituted as the path component.
        split_uri = urlsplit(self._request_plane_uri)._asdict()
        split_uri["path"] = self._response_stream_name
        self._response_uri = str(urlunsplit(split_uri.values()))
        # Cache of per-model streams and their pull subscriptions.
        self._model_streams: Dict[
            tuple[str, str],  # model_name, model_version
            tuple[
                str,  # stream_name
                Optional[nats.js.JetStreamContext.PullSubscription],  # general requests
                Optional[nats.js.JetStreamContext.PullSubscription],  # direct requests
            ],
        ] = {}
        # In-flight requests awaiting responses, keyed by request id.
        self._posted_requests: Dict[
            uuid.UUID,  # request id
            tuple[
                Optional[asyncio.Queue],  # response queue
                Optional[Callable[[ModelInferResponse], None | Awaitable[None]]],
                Optional[Callable[[ModelInferResponse], Awaitable[None]]],
            ],
        ] = {}
        self._jet_stream: Optional[nats.js.JetStreamContext] = None
        self._event_loop: Optional[asyncio.AbstractEventLoop] = None

    def _replace_special_chars(self, stream_name):
        # '.' is a subject-token separator in NATS, so it cannot appear in
        # stream names; replace it with '-'.
        return stream_name.replace(".", "-")

    async def _get_model_stream(
        self, model_name: str, model_version: str, subscribe: bool
    ) -> tuple[
        str,
        Optional[nats.js.JetStreamContext.PullSubscription],
        Optional[nats.js.JetStreamContext.PullSubscription],
    ]:
        """Get (creating on first use) the JetStream stream for a model.

        Returns:
            Tuple of (stream_name, general_subscription,
            directed_subscription); the subscriptions are None unless
            ``subscribe`` is True.

        Raises:
            ValueError: If connect() has not been called yet.
        """
        if self._jet_stream is None:
            raise ValueError(
                "Failed to get model stream: NATS Jetstream not connected!"
            )
        if (model_name, model_version) in self._model_streams:
            return self._model_streams[(model_name, model_version)]
        model_stream_name = self._replace_special_chars(
            f"model-{model_name}-{model_version}"
        )
        # Work-queue retention: each request is delivered to one consumer.
        await self._jet_stream.add_stream(
            name=model_stream_name,
            subjects=[model_stream_name, model_stream_name + ".*"],
            retention=nats.js.api.RetentionPolicy.WORK_QUEUE,
        )
        general_requests = None
        directed_requests = None
        if subscribe:
            # Shared pull consumer for requests addressed to any worker.
            general_requests = await self._jet_stream.pull_subscribe(
                subject=model_stream_name,
                stream=model_stream_name,
                durable=model_stream_name,
            )
            # Per-component pull consumer for requests directed at us.
            directed_subject = f"{model_stream_name}.{self._component_id}"
            directed_durable = f"{model_stream_name}-{self._component_id}"
            directed_requests = await self._jet_stream.pull_subscribe(
                subject=directed_subject,
                stream=model_stream_name,
                durable=directed_durable,
            )
        return self._model_streams.setdefault(
            (model_name, model_version),
            (model_stream_name, general_requests, directed_requests),
        )

    async def _response_callback(self, message):
        """Dispatch an incoming response to its registered queue/handler."""
        await message.ack()
        response = ModelInferResponse()
        response.ParseFromString(message.data)
        request_id = get_icp_request_id(response)
        if request_id in self._posted_requests:
            response_queue, handler, async_handler = self._posted_requests[request_id]
            # Final response: the request is complete, drop its bookkeeping.
            if get_icp_final_response(response):
                del self._posted_requests[request_id]
            # Delivery precedence: queue, then async handler, then sync.
            if response_queue:
                return await response_queue.put(response)
            if async_handler is not None:
                return await async_handler(response)
            if handler is not None:
                return handler(response)

    async def connect(self):
        """Connect to NATS and set up this component's response stream."""
        self._nats_client = await nats.connect(self._request_plane_uri)
        self._jet_stream = self._nats_client.jetstream()
        self._event_loop = asyncio.get_event_loop()
        await self._jet_stream.add_stream(
            name=self._response_stream_name,
            subjects=[self._response_stream_name],
            retention=nats.js.api.RetentionPolicy.WORK_QUEUE,
        )
        # Push subscription: responses are dispatched via _response_callback.
        await self._jet_stream.subscribe(
            self._response_stream_name,
            cb=self._response_callback,
            durable=self._response_stream_name,
            stream=self._response_stream_name,
        )

    async def pull_requests(
        self,
        model_name: str,
        model_version: str,
        number_requests: int = 1,
        timeout: Optional[float] = None,
    ) -> AsyncIterator[ModelInferRequest]:
        """Fetch up to ``number_requests`` pending requests for a model."""
        # Note directed requests and general requests are
        # pulled in parallel. Directed requests are consumed
        # first. If there are more requests than the batch size
        # then extra requests are scheduled for redelivery via nak
        requests: list[ModelInferRequest] = []
        acks = []
        _, general, directed = await self._get_model_stream(
            model_name, model_version, subscribe=True
        )
        # Directed first in the task list so its results are consumed first.
        tasks = [
            asyncio.create_task(
                subscription.fetch(batch=number_requests, timeout=timeout)
            )
            for subscription in [directed, general]
            if subscription
        ]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        for task in tasks:
            if task not in done:
                continue
            try:
                for message in task.result():
                    if len(requests) < number_requests:
                        request = ModelInferRequest()
                        request.ParseFromString(message.data)
                        requests.append(request)
                        acks.append(message.ack())
                    else:
                        # Over batch size: nak for redelivery to another pull.
                        acks.append(message.nak())
            except nats.errors.TimeoutError:
                # Fetch timed out with no messages; nothing to collect.
                continue
        # NOTE(review): gather() is not awaited here — acks/naks run in the
        # background and any errors go unobserved; confirm whether an
        # ``await`` was intended.
        asyncio.gather(*acks)
        return AsyncModelInferRequestIterator(requests)

    @staticmethod
    async def _single_response(response: ModelInferResponse):
        # Wrap a single response in an async generator so post_response can
        # treat both cases uniformly.
        yield response

    async def post_response(
        self,
        request: ModelInferRequest,
        responses: AsyncIterator[ModelInferResponse] | ModelInferResponse,
    ):
        """Publish response(s) back to the requester's response stream.

        Args:
            request: The originating request (carries the response URI).
            responses: A single response or an async iterator of responses.

        Raises:
            ValueError: If not connected, if the request has no request id,
                or if the request did not ask for responses.
        """
        if self._jet_stream is None:
            raise ValueError("Failed to post response: NATS Jetstream not connected!")
        request_id = get_icp_request_id(request)
        if request_id is None:
            raise ValueError("ICP request must have request id")
        response_to_uri = get_icp_response_to_uri(request)
        if not response_to_uri:
            raise ValueError("Attempting to send a response when non requested")
        # The response stream name is carried in the URI's path component.
        parsed = urlsplit(response_to_uri)
        response_stream = parsed.path.replace("/", "")
        if isinstance(responses, ModelInferResponse):
            responses = NatsRequestPlane._single_response(responses)
        async for response in responses:
            # Propagate request identity so the requester can correlate.
            set_icp_request_id(response, request_id)
            response.model_name = request.model_name
            response.model_version = request.model_version
            response.id = request.id
            set_icp_component_id(response, self._component_id)
            await self._jet_stream.publish(
                response_stream,
                response.SerializeToString(),
                stream=response_stream,
            )

    async def post_request(
        self,
        request: ModelInferRequest,
        *,
        component_id: Optional[uuid.UUID] = None,
        response_iterator: bool = False,
        response_handler: Optional[
            Callable[[ModelInferResponse], None | Awaitable[None]]
        ] = None,
    ) -> AsyncIterator[ModelInferResponse]:
        """Publish a request to a model's stream.

        Args:
            request: The request to publish.
            component_id: When given, direct the request to that specific
                component's subject instead of the shared one.
            response_iterator: When True, buffer responses in a queue
                exposed through the returned iterator.
            response_handler: Callback (sync or async) invoked per response.
                Mutually exclusive with ``response_iterator``.

        Returns:
            An AsyncModelInferResponseIterator (empty unless
            ``response_iterator`` is True).

        Raises:
            ValueError: If not connected, or if both ``response_iterator``
                and ``response_handler`` are specified.
        """
        if self._jet_stream is None:
            raise ValueError("Failed to post request: NATS Jetstream not connected!")
        if response_iterator and response_handler:
            raise ValueError(
                "Can only specify either response handler or response iterator"
            )
        async_response_handler = None
        response_queue = None
        if response_handler or response_iterator:
            # Responses were requested: tag the request so workers know
            # where to reply, and register local dispatch state.
            request_id = get_icp_request_id(request)
            if request_id is None:
                request_id = uuid.uuid1()
                set_icp_request_id(request, request_id)
            set_icp_response_to_uri(request, self._response_uri)
            set_icp_component_id(request, self._component_id)
            # Coroutine handlers are dispatched via the async slot;
            # plain callables stay in the sync slot.
            async_response_handler = (
                response_handler
                if asyncio.iscoroutinefunction(response_handler)
                else None
            )
            response_queue = None
            if response_iterator:
                response_queue = asyncio.Queue()
            self._posted_requests[request_id] = (
                response_queue,
                response_handler,
                async_response_handler,
            )
        stream_name, _, _ = await self._get_model_stream(
            request.model_name, request.model_version, subscribe=False
        )
        # Directed requests publish to a per-component sub-subject.
        subject = stream_name
        if component_id:
            subject += f".{component_id}"
        split_uri = urlsplit(self._request_plane_uri)._asdict()
        split_uri["path"] = subject
        set_icp_request_to_uri(request, str(urlunsplit(split_uri.values())))
        await self._jet_stream.publish(
            subject,
            request.SerializeToString(),
            stream=stream_name,
        )
        return AsyncModelInferResponseIterator(response_queue)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment