"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "0abebe388404be84bb38f2c2ed32198635941bcc"
Unverified Commit e61f1c8a authored by Kris Hung's avatar Kris Hung Committed by GitHub
Browse files

chore: Remove nats-py dependency (#1387)

parent 373f1f38
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
nats-py
......@@ -18,10 +18,7 @@ import asyncio
from contextlib import asynccontextmanager
from typing import ClassVar, Optional
from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
from dynamo._core import NatsQueue
class NATSQueue:
......@@ -34,15 +31,7 @@ class NATSQueue:
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
self.nats_url = nats_server
self._nc: Optional[NATS] = None
self._js: Optional[JetStreamContext] = None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
self._subject = f"{self._stream_name}.*"
self.dequeue_timeout = dequeue_timeout
self._subscriber: Optional[JetStreamContext.PullSubscription] = None
self.nats_q = NatsQueue(stream_name, nats_server, dequeue_timeout)
@classmethod
@asynccontextmanager
......@@ -81,75 +70,34 @@ class NATSQueue:
cls._instance = None
async def connect(self):
"""Establish connection and create stream if needed"""
try:
if self._nc is None:
self._nc = NATS()
await self._nc.connect(self.nats_url)
self._js = self._nc.jetstream()
# Check if stream exists, if not create it
try:
await self._js.stream_info(self._stream_name)
except NotFoundError:
await self._js.add_stream(
name=self._stream_name, subjects=[self._subject]
)
# Create persistent subscriber
self._subscriber = await self._js.pull_subscribe(
f"{self._stream_name}.queue", durable="worker-group"
)
except NatsError as e:
await self.close()
raise ConnectionError(f"Failed to connect to NATS: {e}")
await self.nats_q.connect()
async def ensure_connection(self):
"""Ensure we have an active connection"""
if self._nc is None or self._nc.is_closed:
await self.connect()
await self.nats_q.ensure_connection()
async def close(self):
"""Close the connection when done"""
if self._nc:
await self._nc.close()
self._nc = None
self._js = None
self._subscriber = None
await self.nats_q.close()
# TODO: is enqueue/dequeue_object a better name for a general queue?
async def enqueue_task(self, task_data: bytes) -> None:
"""
Enqueue a task using msgspec-encoded data
"""
await self.ensure_connection()
try:
await self._js.publish(f"{self._stream_name}.queue", task_data) # type: ignore
except NatsError as e:
raise RuntimeError(f"Failed to enqueue task: {e}")
await self.nats_q.enqueue_task(task_data)
async def dequeue_task(self) -> Optional[bytes]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await self.ensure_connection()
try:
msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout) # type: ignore
if msgs:
msg = msgs[0]
await msg.ack()
return msg.data
return None
except asyncio.TimeoutError:
return None
except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}")
async def dequeue_task(self, timeout: Optional[float] = None) -> Optional[bytes]:
return await self.nats_q.dequeue_task(timeout)
async def get_queue_size(self) -> int:
"""Get the number of messages currently in the queue"""
await self.ensure_connection()
return await self.nats_q.get_queue_size()
async def clear_queue(self) -> int:
try:
# Get consumer info to get pending messages count
consumer_info = await self._js.consumer_info( # type: ignore
self._stream_name, "worker-group"
)
# Return number of pending messages (real-time queue size)
return consumer_info.num_pending
except NatsError as e:
raise RuntimeError(f"Failed to get queue size: {e}")
cleared_count = 0
# Continue until we can't dequeue any more messages
while True:
# use a small timeout
message = await self.dequeue_task(timeout=0.1)
if message is None:
break
cleared_count += 1
return cleared_count
except Exception as e:
raise RuntimeError(f"Failed to clear queue: {e}")
......@@ -28,7 +28,6 @@ requires-python = ">=3.10"
dependencies = [
"pydantic>=2.10.6,<2.11.0",
"uvloop>=0.21.0",
"nats-py>=2.6.0",
]
classifiers = [
"Development Status :: 4 - Beta",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment