Unverified Commit fe164d72 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: replace async queue with async iter and double decorator (#858)

parent ee2c5938
......@@ -22,6 +22,7 @@ from typing import Any, Dict, Optional, TypeVar
from _bentoml_sdk.service import Service
from _bentoml_sdk.service.dependency import Dependency
from dynamo.runtime import DistributedRuntime
from dynamo.sdk.lib.service import DynamoService
T = TypeVar("T")
......@@ -47,75 +48,31 @@ class DynamoClient:
if name not in self._dynamo_clients:
namespace, component_name = self._service.dynamo_address()
# Create async generator function that uses Queue for streaming
# Create async generator function that directly yields from the stream
async def get_stream(*args, **kwargs):
queue: asyncio.Queue = asyncio.Queue()
if self._runtime is not None:
# Use existing runtime if available
async def stream_worker():
try:
client = (
await self._runtime.namespace(namespace)
.component(component_name)
.endpoint(name)
.client()
)
# TODO: Potentially model dump for a user here so they can pass around Pydantic models
stream = await client.generate(*args, **kwargs)
async for item in stream:
data = item.data()
await queue.put(data)
await queue.put(None)
except Exception:
await queue.put(None)
raise
runtime = self._runtime
else:
# Create dynamo worker if no runtime
from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker()
async def stream_worker(runtime: DistributedRuntime):
try:
# Store runtime for future use
# Create new runtime and store it
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, False)
self._runtime = runtime
try:
client = (
await runtime.namespace(namespace)
.component(component_name)
.endpoint(name)
.client()
)
# Directly yield items from the stream
stream = await client.generate(*args, **kwargs)
async for item in stream:
data = item.data()
await queue.put(data)
await queue.put(None)
except Exception:
await queue.put(None)
raise
# Start worker task with error handling
worker_task = asyncio.create_task(stream_worker())
try:
# Yield items from queue until None received
while True:
item = await queue.get()
if item is None:
break
yield item
finally:
try:
await worker_task
except Exception:
raise
yield item.data()
except Exception as e:
raise e
self._dynamo_clients[name] = get_stream
return self._dynamo_clients[name]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment