Commit 6d2abdba authored by Blazej's avatar Blazej Committed by GitHub
Browse files

feat: Add event plane


Signed-off-by: default avatarPiotr Marcinkiewicz <piotrm@nvidia.com>
Co-authored-by: default avatarPiotr Marcinkiewicz <piotrm@nvidia.com>
Co-authored-by: default avatarNeelay Shah <neelays@nvidia.com>
parent a48d932e
...@@ -17,6 +17,7 @@ fastapi==0.115.6 ...@@ -17,6 +17,7 @@ fastapi==0.115.6
ftfy ftfy
grpcio-tools==1.66.0 grpcio-tools==1.66.0
httpx httpx
msgspec
mypy mypy
numpy numpy
opentelemetry-api opentelemetry-api
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Event Plane example
A basic example that demonstrates how to use the Event Plane API to create an event plane, register an event, and trigger the event.
## Code overview
### Using context managers (recommended)
```python
async def example_with_context_managers():
server_url = "tls://localhost:4222"
component_id = uuid.uuid4()
async with NatsEventPlane(server_url, component_id) as plane:
received_events = []
async def callback(event):
print(event)
received_events.append(event)
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"my_payload"
# Subscribe using context manager
async with await plane.subscribe(callback, event_topic=event_topic, event_type=event_type):
# Publish event
await plane.publish(event, event_type, event_topic)
# Allow time for message to propagate
await asyncio.sleep(2)
```
### Manual resource management
#### 1) Initialize NATS server and create an event plane
```python
server_url = "tls://localhost:4222" # Optional, default is nats://localhost:4222
component_id = uuid.uuid4() # Optional, component_id will be generated if not given
plane = NatsEventPlane(server_url, component_id)
await plane.connect()
```
#### 2) Define the callback function for receiving events
```python
received_events = []
async def callback(event):
print(event)
received_events.append(event)
```
#### 3) Prepare the event event_topic, event type, and event payload
```python
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"my_payload"
```
#### 4) Subscribe to the event event_topic and type and register the callback function
```python
subscription = await plane.subscribe(callback, event_topic=event_topic, event_type=event_type)
```
#### 5) Publish the event
```python
await plane.publish(event, event_type, event_topic)
```
#### 6) Clean up resources
```python
# Unsubscribe when done
await subscription.unsubscribe()
# Disconnect from NATS server
await plane.disconnect()
```
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from triton_distributed.icp.nats_event_plane import DEFAULT_EVENTS_PORT
def parse_args(args=None):
parser = argparse.ArgumentParser(description="Event Plane Example")
parser.add_argument(
"--nats-port",
type=int,
default=DEFAULT_EVENTS_PORT,
help="Nats server port",
)
parser.add_argument(
"--publisher-count",
type=int,
default=1,
help="Number of publishers to deploy",
)
parser.add_argument(
"--subscriber-count",
type=int,
default=10,
help="Number of subscribers to deploy",
)
args = parser.parse_args(args)
return args
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from triton_distributed.icp.nats_event_plane import (
EventTopic,
NatsEventPlane,
compose_nats_url,
)
async def single_publisher_subscriber_example():
# async with aclosing(event_plane()) as event_plane_instance:
# event_plane_instance = await anext(event_plane)
server_url = compose_nats_url()
component_id = str(uuid.uuid4())
plane = NatsEventPlane(server_url, component_id)
await plane.connect()
received_events = []
async def callback(event):
print(event)
print(event.payload)
print(event.typed_payload(bytes))
received_events.append(event)
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"my_payload"
await plane.subscribe(callback, event_topic=event_topic, event_type=event_type)
await plane.publish(event, event_type, event_topic)
# Allow time for message to propagate
await asyncio.sleep(3)
print(f"received_events: {received_events}")
# assert received_events[0][0].event_id == event.event_id
await plane.disconnect()
if __name__ == "__main__":
asyncio.run(single_publisher_subscriber_example())
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uuid
from triton_distributed.icp.nats_event_plane import (
EventTopic,
NatsEventPlane,
compose_nats_url,
)
async def main(args):
server_url = compose_nats_url()
event_plane = NatsEventPlane(server_url, args.component_id)
await event_plane.connect()
try:
event_topic = (
EventTopic(args.event_topic.split(".")) if args.event_topic else None
)
event = args.payload.encode()
await event_plane.publish(event, args.event_type, event_topic)
print(f"Published event from publisher {args.event_topic}")
finally:
await event_plane.disconnect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event publisher script")
parser.add_argument(
"--component-id",
type=uuid.UUID,
default=uuid.uuid4(),
help="Component ID (UUID)",
)
parser.add_argument(
"--event-topic",
type=str,
default=None,
help="Event EventTopic to subscribe to (comma-separated for multiple levels)",
)
parser.add_argument(
"--event-type", type=str, default="test_event", help="Event type"
)
parser.add_argument(
"--payload",
type=str,
default="test_payload",
help="Payload to be published with event.",
)
args = parser.parse_args()
asyncio.run(main(args))
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uuid
from triton_distributed.icp.nats_event_plane import (
EventTopic,
NatsEventPlane,
compose_nats_url,
)
async def main(args):
server_url = compose_nats_url()
event_plane = NatsEventPlane(server_url, uuid.uuid4())
async def callback(received_event):
print(
f"""
Subscriber {args.subscriber_id}
received event: {received_event.event_id}
event payload: {received_event.payload.tobytes().decode("utf-8")}
event.topic: {received_event.event_topic}
event.type: {received_event.event_type}
event.component_id: {received_event.component_id}
event.timestamp: {received_event.timestamp}
"""
)
await event_plane.connect()
try:
event_topic = (
EventTopic(args.event_topic.split(".")) if args.event_topic else None
)
print(f"Subscribing to event_topic: {args.event_topic}")
await event_plane.subscribe(
callback,
event_topic=event_topic,
event_type=args.event_type,
component_id=args.component_id,
)
print(
f"Subscriber {args.subscriber_id} is listening on event_topic {event_topic} with event type '{args.event_type or 'all'}' "
+ f"component ID '{args.component_id}'"
)
while True:
await asyncio.sleep(5) # Keep the subscriber running
print(f"Subscriber {args.subscriber_id} is still running")
finally:
await event_plane.disconnect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event subscriber script")
parser.add_argument(
"--event-topic",
type=str,
default=None,
help="Event EventTopic to subscribe to (comma-separated for multiple levels)",
)
parser.add_argument(
"--event-type",
type=str,
default=None,
help="Event type to filter (default: None for all types)",
)
parser.add_argument(
"--component-id",
type=uuid.UUID,
default=None,
help="Component ID (UUID) for the subscriber",
)
args = parser.parse_args()
asyncio.run(main(args))
...@@ -14,6 +14,16 @@ ...@@ -14,6 +14,16 @@
# limitations under the License. # limitations under the License.
from triton_distributed.icp.data_plane import DataPlane as DataPlane from triton_distributed.icp.data_plane import DataPlane as DataPlane
from triton_distributed.icp.event_plane import Event as Event
from triton_distributed.icp.event_plane import EventPlane as EventPlane
from triton_distributed.icp.event_plane import EventTopic as EventTopic
from triton_distributed.icp.nats_event_plane import (
DEFAULT_EVENTS_HOST as DEFAULT_EVENTS_HOST,
)
from triton_distributed.icp.nats_event_plane import (
DEFAULT_EVENTS_PORT as DEFAULT_EVENTS_PORT,
)
from triton_distributed.icp.nats_event_plane import NatsEventPlane as NatsEventPlane
from triton_distributed.icp.nats_request_plane import ( from triton_distributed.icp.nats_request_plane import (
NatsRequestPlane as NatsRequestPlane, NatsRequestPlane as NatsRequestPlane,
) )
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import re
import uuid
from abc import abstractmethod
from datetime import datetime
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Type, Union
EVENT_TOPIC_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
def _validate_topics(topics: List[str]) -> bool:
"""
Checks if all strings in the list are alphanumeric and can contain underscores (_) and hyphens (-).
:param subjects: List of strings to validate
:return: True if all strings are valid, False otherwise
"""
pattern = EVENT_TOPIC_PATTERN
return all(pattern.match(topic) for topic in topics)
@dataclasses.dataclass
class EventTopic:
"""Event event_topic class for identifying event streams."""
event_topic: str
def __init__(self, event_topic: Union[List[str], str]):
"""Initialize the event_topic.
Args:
event_topic (Union[List[str], str]): The event_topic as a list of strings or a single string. Strings should be alphanumeric + underscore and '-' characters only. The list forms a hierarchy of topics.
"""
if isinstance(event_topic, str):
if "." in event_topic:
event_topic_list = event_topic.split(".")
else:
event_topic_list = [event_topic]
else:
event_topic_list = event_topic
if not _validate_topics(event_topic_list):
raise ValueError(
"Invalid event_topic. Only alphanumeric characters, underscores, and hyphens are allowed."
)
event_topic_string = ".".join(event_topic_list)
self.event_topic = event_topic_string
def __str__(self):
return self.event_topic
class Event:
"""Event class for representing events."""
@property
@abstractmethod
def event_id(self) -> uuid.UUID:
pass
@property
@abstractmethod
def event_type(self) -> str:
pass
@property
@abstractmethod
def timestamp(self) -> datetime:
pass
@property
@abstractmethod
def component_id(self) -> uuid.UUID:
pass
@property
@abstractmethod
def event_topic(self) -> Optional[EventTopic]:
pass
@property
@abstractmethod
def payload(self) -> bytes:
pass
@abstractmethod
def typed_payload(self, payload_type: Optional[Type | str] = None) -> Any:
pass
class EventSubscription(AsyncIterator[Event]):
@abstractmethod
async def __anext__(self) -> Event:
pass
@abstractmethod
def __aiter__(self):
return self
@abstractmethod
def unsubscribe(self):
pass
class EventPlane:
"""EventPlane interface for publishing and subscribing to events."""
@abstractmethod
async def connect(self):
"""Connect to the event plane."""
pass
@abstractmethod
async def publish(
self,
event: Union[bytes, Any],
event_type: str,
event_topic: Optional[EventTopic],
) -> Event:
"""Publish an event to the event plane.
Args:
event (Union[bytes, Any]): Event payload
event_type (str): Event type
event_topic (Optional[EventTopic]): Event event_topic
"""
pass
@abstractmethod
async def subscribe(
self,
callback: Callable[[Event], Awaitable[None]],
event_topic: Optional[EventTopic] = None,
event_type: Optional[str] = None,
component_id: Optional[uuid.UUID] = None,
) -> EventSubscription:
"""Subscribe to events on the event plane.
Args:
callback (Callable[[bytes, bytes], Awaitable[None]]): Callback function to be called when an event is received
event_topic (Optional[EventTopic]): Event event_topic
event_type (Optional[str]): Event type
component_id (Optional[uuid.UUID]): Component ID
"""
pass
@abstractmethod
async def disconnect(self):
"""Disconnect from the event plane."""
pass
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import datetime
import logging
import os
import uuid
from typing import Any, Awaitable, Callable, List, Optional, Union
import msgspec
import nats
from triton_distributed.icp import EventTopic
from triton_distributed.icp.event_plane import Event, EventSubscription
from triton_distributed.icp.on_demand_event import (
EventMetadata,
OnDemandEvent,
_serialize_metadata,
)
logger = logging.getLogger(__name__)
DEFAULT_EVENTS_PORT = int(os.getenv("DEFAULT_EVENTS_PORT", 4222))
DEFAULT_EVENTS_HOST = os.getenv("DEFAULT_EVENTS_HOST", "localhost")
DEFAULT_EVENTS_PROTOCOL = os.getenv("DEFAULT_EVENTS_PROTOCOL", "nats")
DEFAULT_CONNECTION_TIMEOUT = int(os.getenv("DEFAULT_CONNECTION_TIMEOUT", 30))
EVENT_PLANE_NATS_PREFIX = "event_plane_nats_v1"
def compose_nats_uri(
protocol: str = DEFAULT_EVENTS_PROTOCOL,
host: str = DEFAULT_EVENTS_HOST,
port: int = DEFAULT_EVENTS_PORT,
) -> str:
"""Compose a NATS URL from components.
Args:
protocol: The protocol to use (tls or nats). Defaults to DEFAULT_EVENTS_PROTOCOL.
host: The host to connect to. Defaults to DEFAULT_EVENTS_HOST.
port: The port to connect to. Defaults to DEFAULT_EVENTS_PORT.
Returns:
str: The composed NATS URL
"""
return f"{protocol}://{host}:{port}"
class NatsEventSubscription(EventSubscription):
def __init__(
self,
nc_sub: nats.aio.subscription.Subscription,
nats_connection: Any,
subject: str,
topic: EventTopic,
):
self._nc_sub: Optional[nats.aio.subscription.Subscription] = nc_sub
self._nats = nats_connection
self._subject = subject
self._topic = topic
self._unsubscribe_event: asyncio.Event = asyncio.Event()
async def __anext__(self):
if self._nc_sub is None:
raise StopAsyncIteration
if not self._nats.is_connected:
if self._error is not None:
raise RuntimeError(
f"NATS connection error: {self._error}"
) from self._error
else:
raise RuntimeError("NATS connection failure.")
else:
failure_task = asyncio.create_task(self._nats.wait_for_failure())
next_task = asyncio.create_task(self._nc_sub.next_msg())
_ = await asyncio.wait(
[next_task, failure_task], return_when=asyncio.FIRST_COMPLETED
)
if failure_task.done():
logger.warning("NATS connection failure.")
try:
next_task.cancel()
await next_task
except asyncio.CancelledError:
pass
raise RuntimeError("NATS connection failure.") from failure_task.exception()
else:
try:
failure_task.cancel()
await failure_task
except asyncio.CancelledError:
pass
msg = next_task.result()
metadata, event_payload = NatsEventPlane._extract_metadata_and_payload(
msg.data
)
event = OnDemandEvent(event_payload, metadata)
return event
def __aiter__(self):
return self
async def unsubscribe(self):
if self._nc_sub is None:
return
if self._nats.is_connected():
await self._nc_sub.unsubscribe()
self._nc_sub = None
else:
logger.warning("NATS not connected. Cannot unsubscribe.")
@property
def subject(self):
return self._subject
@property
def topic(self):
return self._topic
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.unsubscribe()
return False # Don't suppress exceptions
class NatsEventPlane:
"""EventPlane implementation using NATS."""
def __init__(
self,
server_uri: str = compose_nats_uri(),
component_id: Optional[uuid.UUID] = uuid.uuid4(),
run_callback_in_parallel: bool = False,
):
"""Initialize the NATS event plane.
Args:
server_uri: URI of the NATS server. If None, will be composed using environment variables.
component_id: Component ID.
"""
self._run_callback_in_parallel = run_callback_in_parallel
if server_uri is None:
server_uri = compose_nats_uri()
self._server_uri = server_uri
if component_id is None:
component_id = uuid.uuid4()
self._component_id = component_id
self._nc = nats.NATS()
self._error: Optional[Exception] = None
self._connected = False
self._failure_event: Optional[asyncio.Event] = None
async def wait_for_failure(self):
"""Wait for a failure event."""
if self._failure_event is not None:
await self._failure_event.wait()
raise RuntimeError("NATS connection failure.") from self._error
else:
raise RuntimeError("NATS connection failure event is None")
def is_connected(self):
return self._connected
async def connect(self):
"""Connect to the NATS server."""
if self._connected:
return
async def error_cb(e):
logger.warning("NATS error: %s", e)
self._error = e
if self._failure_event is not None:
self._failure_event.set()
self._failure_event = asyncio.Event()
else:
logger.error(f"NATS connection failure event is None for error {e}")
async def reconnected_cb():
logger.debug("NATS reconnected")
self._connected = True
async def disconnected_cb():
logger.debug("NATS disconnected")
self._connected = False
async def closed_cb():
logger.debug("NATS closed")
self._connected = False
self._failure_event = asyncio.Event()
try:
async with asyncio.timeout(DEFAULT_CONNECTION_TIMEOUT):
logger.debug(f"Connecting to NATS server: {self._server_uri}")
connect_task = asyncio.create_task(
self._nc.connect(
self._server_uri,
error_cb=error_cb,
reconnected_cb=reconnected_cb,
disconnected_cb=disconnected_cb,
closed_cb=closed_cb,
)
)
failed_task = asyncio.create_task(self.wait_for_failure())
await asyncio.wait(
[connect_task, failed_task], return_when=asyncio.FIRST_COMPLETED
)
if failed_task.done():
try:
connect_task.cancel()
await connect_task
except asyncio.CancelledError:
pass
raise RuntimeError(
"NATS connection failure."
) from failed_task.exception()
else:
try:
failed_task.cancel()
await failed_task
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
raise RuntimeError(
f"NATS connection timeout {DEFAULT_CONNECTION_TIMEOUT} reached."
)
logger.debug(f"Connected to NATS server: {self._server_uri}")
self._connected = True
async def publish(
self,
payload: bytes | Any,
event_type: Optional[str] = None,
event_topic: Optional[EventTopic] = None,
timestamp: Optional[datetime.datetime] = datetime.datetime.now(datetime.UTC),
event_id: Optional[uuid.UUID] = uuid.uuid4(),
) -> Event:
"""Publish an event to the NATS server.
Args:
payload: Event payload.
event_type: Type of the event.
event_topic: EventTopic of the event.
"""
if not self._connected:
if self._error:
raise RuntimeError(
f"NATS connection error: {self._error}"
) from self._error
else:
raise RuntimeError("NATS not connected.")
if timestamp is None:
timestamp = datetime.datetime.now(datetime.UTC)
if event_id is None:
event_id = uuid.uuid4()
event_metadata = EventMetadata(
event_id=event_id,
event_topic=event_topic,
event_type=event_type if event_type else str(type(payload).__name__),
timestamp=timestamp,
component_id=self._component_id,
)
metadata_serialized = _serialize_metadata(event_metadata)
metadata_size = len(metadata_serialized).to_bytes(4, byteorder="big")
# Concatenate metadata size, metadata, and event payload
if isinstance(payload, bytes):
message = metadata_size + metadata_serialized + payload
else:
message = metadata_size + metadata_serialized + msgspec.json.encode(payload)
subject = self._compose_publish_subject(event_metadata)
await self._nc.publish(subject, message)
event_with_metadata = OnDemandEvent(
payload, metadata_serialized, event_metadata
)
return event_with_metadata
async def subscribe(
self,
callback: Optional[Callable[[Event], Awaitable[None]]] = None,
event_topic: Optional[EventTopic | str | List[str]] = None,
event_type: Optional[str] = "*",
component_id: Optional[uuid.UUID] = None,
) -> EventSubscription:
"""Subscribe to events on the NATS server.
Args:
callback: Callback function to be called when an event is received.
event_topic: Event event_topic.
event_type: Event type.
component_id: Component ID.
"""
if not self._connected:
if self._error:
raise RuntimeError(
f"NATS connection error: {self._error}"
) from self._error
else:
raise RuntimeError("NATS not connected.")
async def _message_handler(msg):
metadata, event_payload = NatsEventPlane._extract_metadata_and_payload(
msg.data
)
event = OnDemandEvent(event_payload, metadata)
async def wrapper():
if callback is not None:
await callback(event) # Ensure it's a proper coroutine
if self._run_callback_in_parallel:
if callback is not None:
asyncio.create_task(wrapper()) # Run in parallel
else:
if callback is not None:
await callback(event) # Await normally
subject_str, topic = self._compose_subscribe_subject(
event_topic, event_type, component_id
)
_cb = _message_handler if callback is not None else None
sub = await self._nc.subscribe(subject_str, cb=_cb)
event_sub = NatsEventSubscription(sub, self, subject_str, topic)
return event_sub
async def disconnect(self):
"""Disconnect from the NATS server."""
if not self._connected:
return
await self._nc.close()
self._error = RuntimeError("NATS connection closed by disconnect.")
if self._failure_event is not None:
self._failure_event.set()
self._connected = False
def _compose_publish_subject(self, event_metadata: EventMetadata):
return f"{EVENT_PLANE_NATS_PREFIX}.{event_metadata.event_type}.{event_metadata.component_id}.{str(event_metadata.event_topic) + '.' if event_metadata.event_topic else ''}trunk"
def _compose_subscribe_subject(
self,
event_topic: Optional[Union[EventTopic, str, List[str]]] = None,
event_type: Optional[str] = None,
component_id: Optional[uuid.UUID] = None,
):
if isinstance(event_topic, str) or isinstance(event_topic, list):
event_topic_obj = EventTopic(event_topic)
else:
event_topic_obj = event_topic
return (
f"{EVENT_PLANE_NATS_PREFIX}.{event_type or '*'}.{component_id or '*'}.{str(event_topic_obj) + '.' if event_topic else ''}>",
event_topic_obj,
)
@staticmethod
def _extract_metadata_and_payload(message: bytes):
# Extract metadata size
message_view = memoryview(message)
metadata_size = int.from_bytes(message_view[:4], byteorder="big")
# Extract metadata and event
metadata_serialized = message_view[4 : 4 + metadata_size]
event = message_view[4 + metadata_size :]
return metadata_serialized, event
@property
def component_id(self) -> uuid.UUID:
return self._component_id
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disconnect()
return False # Don't suppress exceptions
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import dataclasses
import uuid
from datetime import datetime
from typing import Any, Optional, Type
import msgspec
from triton_distributed.icp.event_plane import Event, EventTopic
@dataclasses.dataclass
class EventMetadata:
"""
Class keeps metadata of an event.
"""
event_id: uuid.UUID
event_type: str
timestamp: datetime
component_id: uuid.UUID
event_topic: Optional[EventTopic] = None
def _deserialize_metadata(event_metadata_serialized: bytes):
event_metadata_dict = msgspec.json.decode(event_metadata_serialized)
topic_meta = event_metadata_dict["event_topic"]
topic_list = topic_meta["event_topic"].split(".") if topic_meta else []
topic_obj = EventTopic(topic_list)
metadata = EventMetadata(
**{
**event_metadata_dict,
"event_topic": topic_obj,
"event_id": uuid.UUID(event_metadata_dict["event_id"]),
"component_id": uuid.UUID(event_metadata_dict["component_id"]),
"timestamp": datetime.fromisoformat(event_metadata_dict["timestamp"]),
}
)
return metadata
def _serialize_metadata(event_metadata: EventMetadata) -> bytes:
def hook(obj):
if isinstance(obj, uuid.UUID):
return str(obj)
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, EventTopic):
return list(obj.event_topic.split("."))
else:
raise NotImplementedError(f"Type {type(obj)} is not serializable.")
json_string = msgspec.json.encode(event_metadata, enc_hook=hook)
return json_string
def _get_type(type_name: str):
# Check in builtins for the type
builtin_type = getattr(builtins, type_name, None)
if builtin_type and isinstance(builtin_type, type):
return builtin_type
# Check in globals for the type
global_type = globals().get(type_name)
if global_type and isinstance(global_type, type):
return global_type
return None
class OnDemandEvent(Event):
"""LazyEvent class for representing events."""
def __init__(
self,
payload: bytes,
event_metadata_serialized: bytes,
event_metadata: Optional[EventMetadata] = None,
):
"""Initialize the event.
Args:
event_metadata (EventMetadata): Event metadata
event (bytes): Event payload
"""
self._payload = payload
self._event_metadata_serialized = event_metadata_serialized
self._event_metadata = event_metadata
@property
def _metadata(self):
if not self._event_metadata:
self._event_metadata = _deserialize_metadata(
self._event_metadata_serialized
)
return self._event_metadata
@property
def event_id(self) -> uuid.UUID:
return self._metadata.event_id
@property
def event_type(self) -> str:
return self._metadata.event_type
@property
def timestamp(self) -> datetime:
return self._metadata.timestamp
@property
def component_id(self) -> uuid.UUID:
return self._metadata.component_id
@property
def event_topic(self) -> Optional[EventTopic]:
return self._metadata.event_topic
@property
def payload(self) -> bytes:
return self._payload
def typed_payload(self, payload_type: Optional[Type | str] = None) -> Any:
if payload_type is None:
payload_type = self.event_type
if isinstance(payload_type, str):
payload_type = _get_type(payload_type)
if payload_type is not None and payload_type is not bytes:
try:
return msgspec.json.decode(self._payload, type=payload_type)
except Exception as e:
raise ValueError(
f"Unable to convert payload {self._payload!r} to type {payload_type} from event type {self.event_type}"
) from e
elif payload_type is bytes:
return bytes(self._payload)
else:
raise ValueError(
f"Unable to convert payload {self._payload!r} to type {payload_type} from event type {self.event_type}"
)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from datetime import datetime
import pytest
from triton_distributed.icp.nats_event_plane import (
EventMetadata,
EventTopic,
NatsEventPlane,
)
pytestmark = pytest.mark.pre_merge
class TestEventTopic:
def test_from_string(self):
topic_str = "level1"
event_topic = EventTopic(topic_str)
assert event_topic.event_topic == topic_str
def test_to_string(self):
event_topic = EventTopic(["level1", "level2"])
assert str(event_topic) == "level1.level2"
class TestEvent:
@pytest.fixture
def sample_event_metadata(self):
event_topic = EventTopic("test.event_topic")
return EventMetadata(
event_id=uuid.uuid4(),
event_topic=event_topic,
event_type="test_event",
timestamp=datetime.utcnow(),
component_id=uuid.uuid4(),
)
class TestEventPlaneNats:
@pytest.fixture
def event_plane_instance(self):
server_url = "tls://localhost:4222"
component_id = uuid.uuid4()
return NatsEventPlane(server_url, component_id)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import uuid
from typing import List
import pytest
from utils import event_plane, nats_server
from triton_distributed.icp import Event, EventTopic, NatsEventPlane
pytestmark = pytest.mark.pre_merge
@pytest.mark.asyncio
class TestEventPlaneFunctional:
@pytest.mark.asyncio
async def test_single_publisher_subscriber(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
async def callback(event):
received_events.append(event)
print(event)
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"test_payload"
await event_plane.subscribe(
callback, event_topic=event_topic, event_type=event_type
)
event_metadata = await event_plane.publish(event, event_type, event_topic)
# Allow time for message to propagate
await asyncio.sleep(2)
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
@pytest.mark.asyncio
async def test_single_publisher_subscriber_iterator(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"test_payload"
subscription = await event_plane.subscribe(
event_topic=event_topic, event_type=event_type
)
event_metadata = await event_plane.publish(
event, event_topic=event_topic, event_type=event_type
)
# Allow time for message to propagate
await asyncio.sleep(2)
async for x in subscription:
print(x.timestamp)
print(x.event_id)
print(x.event_type)
print(x.event_topic)
print(x.payload)
received_events.append(x)
break
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
@pytest.mark.asyncio
async def test_default_subscription(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
event = b"test_payload"
subscription = await event_plane.subscribe()
event_metadata = await event_plane.publish(
event,
)
# Allow time for message to propagate
await asyncio.sleep(2)
async for x in subscription:
print(x.timestamp)
print(x.event_id)
print(x.event_type)
print(x.event_topic)
print(x.payload)
received_events.append(x)
break
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
@pytest.mark.asyncio
async def test_custom_type(self, nats_server, event_plane):
print(f"Print loop test: {id(asyncio.get_running_loop())}")
received_events: List[Event] = []
@dataclasses.dataclass
class MyEvent:
test: str
index: int
event = MyEvent("hello", 0)
subscription = await event_plane.subscribe()
event_metadata = await event_plane.publish(
event,
)
# Allow time for message to propagate
await asyncio.sleep(2)
async for x in subscription:
print(x.timestamp)
print(x.event_id)
print(x.event_type)
print(x.event_topic)
print(x.payload)
print(x.typed_payload(MyEvent))
received_events.append(x)
break
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
assert isinstance(received_events[0].typed_payload(MyEvent), type(event))
assert isinstance(received_events[0].typed_payload(dict), dict)
@pytest.mark.asyncio
async def test_one_publisher_multiple_subscribers(self, nats_server):
results_1: List[Event] = []
results_2: List[Event] = []
results_3: List[Event] = []
async def callback_1(event):
results_1.append(event)
async def callback_2(event):
results_2.append(event)
async def callback_3(event):
results_3.append(event)
event_topic = EventTopic(["test"])
event_type = "multi_event"
event = b"multi_payload"
# async with event_plane_context() as event_plane1:
server_url = "tls://localhost:4222"
component_id = uuid.uuid4()
event_plane2 = NatsEventPlane(server_url, component_id)
try:
await event_plane2.connect()
try:
subscription1 = await event_plane2.subscribe(
callback_1, event_topic=event_topic
)
try:
subscription2 = await event_plane2.subscribe(
callback_2, event_topic=event_topic
)
try:
subscription3 = await event_plane2.subscribe(
callback_3, event_type=event_type
)
component_id = uuid.uuid4()
event_plane1 = NatsEventPlane(server_url, component_id)
try:
await event_plane1.connect()
ch1 = EventTopic(["test", "1"])
ch2 = EventTopic(["test", "2"])
await event_plane1.publish(event, event_type, ch1)
await event_plane1.publish(event, event_type, ch2)
# Allow time for message propagation
await asyncio.sleep(2)
assert len(results_1) == 2
assert len(results_2) == 2
assert len(results_3) == 2
finally:
await event_plane1.disconnect()
finally:
await subscription3.unsubscribe()
finally:
await subscription2.unsubscribe()
finally:
await subscription1.unsubscribe()
finally:
await event_plane2.disconnect()
@pytest.mark.asyncio
async def test_context_manager(self, nats_server):
"""Test that context managers properly handle connection/disconnection and subscription/unsubscription."""
received_events: List[Event] = []
event_topic = EventTopic(["test", "event_topic"])
event_type = "test_event"
event = b"test_payload"
# Test successful operation with context managers
async with NatsEventPlane() as plane:
assert plane.is_connected()
async def callback(event):
received_events.append(event)
async with await plane.subscribe(
callback, event_topic=event_topic, event_type=event_type
) as subscription:
assert subscription._nc_sub is not None
event_metadata = await plane.publish(event, event_type, event_topic)
await asyncio.sleep(2) # Allow time for message to propagate
# After subscription context, should be unsubscribed
assert subscription._nc_sub is None
# After plane context, should be disconnected
assert not plane.is_connected()
assert len(received_events) == 1
assert received_events[0].event_id == event_metadata.event_id
# Test error handling in context managers
with pytest.raises(RuntimeError):
async with NatsEventPlane() as plane:
async with await plane.subscribe(
callback, event_topic=event_topic, event_type=event_type
):
raise RuntimeError("Test error")
# Should not reach here
pytest.fail("Should have raised exception")
# Should not reach here
pytest.fail("Should have raised exception")
# Even after error, resources should be cleaned up
assert not plane.is_connected()
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import subprocess
import time
from contextlib import asynccontextmanager
import pytest_asyncio
from triton_distributed.icp import (
DEFAULT_EVENTS_HOST,
DEFAULT_EVENTS_PORT,
NatsEventPlane,
)
logger = logging.getLogger(__name__)
def is_port_in_use(port: int) -> bool:
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
@pytest_asyncio.fixture(loop_scope="session")
async def nats_server():
"""Fixture to start and stop a NATS server."""
process = None
try:
# Raise more intuitive error to developer if port is already in-use.
if is_port_in_use(DEFAULT_EVENTS_PORT):
raise RuntimeError(
f"ERROR: NATS Port {DEFAULT_EVENTS_PORT} already in use. Is a nats-server already running?"
)
# Start NATS server
logger.info("NATS server starting")
process = subprocess.Popen(
[
"nats-server",
"-p",
str(DEFAULT_EVENTS_PORT),
"-addr",
DEFAULT_EVENTS_HOST,
],
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
while not is_port_in_use(DEFAULT_EVENTS_PORT):
logger.debug("Waiting for NATS server to start...")
time.sleep(0.2)
logger.info("NATS server started")
yield process
finally:
# Stop the NATS server
if process:
logger.debug("Closing NATS server")
process.terminate()
# communicate() ensures we consume all stdout/stderr so they can close
out, err = process.communicate()
# If you want to log them:
logger.debug("NATS server stdout: %s", out.decode())
logger.debug("NATS server stderr: %s", err.decode())
if process.stdout:
process.stdout.close()
if process.stderr:
process.stderr.close()
# Stop the NATS server
process.wait()
@asynccontextmanager
async def event_plane_context():
# with nats_server_context() as server:
print(f"Print loop plane context: {id(asyncio.get_running_loop())}")
plane = NatsEventPlane()
await plane.connect()
yield plane
await plane.disconnect()
@pytest_asyncio.fixture(loop_scope="function")
async def event_plane():
print(f"Print loop plane: {id(asyncio.get_running_loop())}")
plane = NatsEventPlane()
await plane.connect()
yield plane
await plane.disconnect()
...@@ -83,6 +83,9 @@ markers = [ ...@@ -83,6 +83,9 @@ markers = [
line-length = 88 line-length = 88
indent-width = 4 indent-width = 4
[tool.ruff.lint.extend-per-file-ignores]
"icp/tests/**/test_*.py" = ["F811", "F401"]
[tool.mypy] [tool.mypy]
# --disable-error-code: WAR large set of errors due to mypy not being run # --disable-error-code: WAR large set of errors due to mypy not being run
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment