Unverified commit fe164d72, authored by Biswa Panda and committed by GitHub
Browse files

feat: replace async queue with async iter and double decorator (#858)

parent ee2c5938
...@@ -22,6 +22,7 @@ from typing import Any, Dict, Optional, TypeVar ...@@ -22,6 +22,7 @@ from typing import Any, Dict, Optional, TypeVar
from _bentoml_sdk.service import Service from _bentoml_sdk.service import Service
from _bentoml_sdk.service.dependency import Dependency from _bentoml_sdk.service.dependency import Dependency
from dynamo.runtime import DistributedRuntime
from dynamo.sdk.lib.service import DynamoService from dynamo.sdk.lib.service import DynamoService
T = TypeVar("T") T = TypeVar("T")
...@@ -47,75 +48,31 @@ class DynamoClient: ...@@ -47,75 +48,31 @@ class DynamoClient:
if name not in self._dynamo_clients: if name not in self._dynamo_clients:
namespace, component_name = self._service.dynamo_address() namespace, component_name = self._service.dynamo_address()
# Create async generator function that uses Queue for streaming # Create async generator function that directly yields from the stream
async def get_stream(*args, **kwargs): async def get_stream(*args, **kwargs):
queue: asyncio.Queue = asyncio.Queue()
if self._runtime is not None: if self._runtime is not None:
# Use existing runtime if available # Use existing runtime if available
async def stream_worker(): runtime = self._runtime
try:
client = (
await self._runtime.namespace(namespace)
.component(component_name)
.endpoint(name)
.client()
)
# TODO: Potentially model dump for a user here so they can pass around Pydantic models
stream = await client.generate(*args, **kwargs)
async for item in stream:
data = item.data()
await queue.put(data)
await queue.put(None)
except Exception:
await queue.put(None)
raise
else: else:
# Create dynamo worker if no runtime # Create new runtime and store it
from dynamo.runtime import DistributedRuntime, dynamo_worker loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, False)
@dynamo_worker()
async def stream_worker(runtime: DistributedRuntime):
try:
# Store runtime for future use
self._runtime = runtime self._runtime = runtime
try:
client = ( client = (
await runtime.namespace(namespace) await runtime.namespace(namespace)
.component(component_name) .component(component_name)
.endpoint(name) .endpoint(name)
.client() .client()
) )
# Directly yield items from the stream
stream = await client.generate(*args, **kwargs) stream = await client.generate(*args, **kwargs)
async for item in stream: async for item in stream:
data = item.data() yield item.data()
await queue.put(data) except Exception as e:
await queue.put(None) raise e
except Exception:
await queue.put(None)
raise
# Start worker task with error handling
worker_task = asyncio.create_task(stream_worker())
try:
# Yield items from queue until None received
while True:
item = await queue.get()
if item is None:
break
yield item
finally:
try:
await worker_task
except Exception:
raise
self._dynamo_clients[name] = get_stream self._dynamo_clients[name] = get_stream
return self._dynamo_clients[name] return self._dynamo_clients[name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment