Unverified Commit 30bbfe0c authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: support multiple endpoints (#857)

parent 974201c8
...@@ -239,7 +239,7 @@ def main( ...@@ -239,7 +239,7 @@ def main(
dynamo_context["component"] = component dynamo_context["component"] = component
dynamo_context["endpoints"] = endpoints dynamo_context["endpoints"] = endpoints
class_instance = service.inner() class_instance = service.inner()
twm = [] dynamo_handlers = []
for name, endpoint in dynamo_endpoints.items(): for name, endpoint in dynamo_endpoints.items():
bound_method = endpoint.func.__get__(class_instance) bound_method = endpoint.func.__get__(class_instance)
# Only pass request type for now, use Any for response # Only pass request type for now, use Any for response
...@@ -248,7 +248,7 @@ def main( ...@@ -248,7 +248,7 @@ def main(
dynamo_wrapped_method = dynamo_endpoint(endpoint.request_type, Any)( dynamo_wrapped_method = dynamo_endpoint(endpoint.request_type, Any)(
bound_method bound_method
) )
twm.append(dynamo_wrapped_method) dynamo_handlers.append(dynamo_wrapped_method)
# Run startup hooks before setting up endpoints # Run startup hooks before setting up endpoints
for name, member in vars(class_instance.__class__).items(): for name, member in vars(class_instance.__class__).items():
if callable(member) and getattr( if callable(member) and getattr(
...@@ -280,7 +280,14 @@ def main( ...@@ -280,7 +280,14 @@ def main(
logger.info( logger.info(
f"Appended lease {lease.id()}/{lease.id():x} to {watcher_name}" f"Appended lease {lease.id()}/{lease.id():x} to {watcher_name}"
) )
result = await endpoints[0].serve_endpoint(twm[0], lease) # Launch serve_endpoint for all endpoints concurrently
tasks = [
endpoint.serve_endpoint(handler, lease)
for endpoint, handler in zip(endpoints, dynamo_handlers)
]
# Wait for all tasks to complete
await asyncio.gather(*tasks)
if class_instance.__class__.__name__ == "PrefillWorker": if class_instance.__class__.__name__ == "PrefillWorker":
await asyncio.wait_for(class_instance.task, timeout=None) await asyncio.wait_for(class_instance.task, timeout=None)
......
...@@ -81,6 +81,15 @@ class Backend: ...@@ -81,6 +81,15 @@ class Backend:
for token in text.split(): for token in text.split():
yield f"Backend: {token}" yield f"Backend: {token}"
@dynamo_endpoint()
async def generate_v2(self, req: RequestType):
"""Generate tokens."""
req_text = req.text
print(f"Backend received: {req_text}")
text = f"{req_text}-back"
for token in text.split():
yield f"Backend generate_v2: {token}"
@service( @service(
resources={"cpu": "2"}, resources={"cpu": "2"},
...@@ -129,6 +138,9 @@ class Middle: ...@@ -129,6 +138,9 @@ class Middle:
async for back_resp in self.backend.generate(txt.model_dump_json()): async for back_resp in self.backend.generate(txt.model_dump_json()):
print(f"Frontend received back_resp: {back_resp}") print(f"Frontend received back_resp: {back_resp}")
yield f"Frontend: {back_resp}" yield f"Frontend: {back_resp}"
async for back_resp in self.backend.generate_v2(txt.model_dump_json()):
print(f"Frontend received back_resp: {back_resp}")
yield f"Frontend: {back_resp}"
else: else:
async for back_resp in self.backend2.generate(txt.model_dump_json()): async for back_resp in self.backend2.generate(txt.model_dump_json()):
print(f"Frontend received back_resp: {back_resp}") print(f"Frontend received back_resp: {back_resp}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment