Unverified commit 6d3b35fa authored by Byron Hsu, committed by GitHub

[PD] Simplify mini LB (#4911)


Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
parent a73c4df4
""" """
Minimal HTTP load balancer for prefill and decode servers for testing purpose. Minimal HTTP load balancer for prefill and decode servers for testing.
""" """
import asyncio import asyncio
...@@ -22,64 +22,59 @@ class MiniLoadBalancer: ...@@ -22,64 +22,59 @@ class MiniLoadBalancer:
def select_pair(self): def select_pair(self):
return random.choice(self.prefill_servers), random.choice(self.decode_servers) return random.choice(self.prefill_servers), random.choice(self.decode_servers)
async def generate_request(self, request_data): async def generate(
prefill_server, decode_server = self.select_pair() self, modified_request, prefill_server, decode_server
) -> ORJSONResponse:
# Parse and transform prefill_server
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
bootstrap_host = f"{hostname}"
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": bootstrap_host,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = [ tasks = [
session.post(f"{prefill_server}/generate", json=modified_request), session.post(f"{prefill_server}/generate", json=modified_request),
session.post(f"{decode_server}/generate", json=modified_request), session.post(f"{decode_server}/generate", json=modified_request),
] ]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
return ORJSONResponse(
content=await decode_response.json(),
status_code=decode_response.status,
)
prefill_response = None async def generate_stream(self, modified_request, prefill_server, decode_server):
decode_response = None async def stream_results():
async with aiohttp.ClientSession(
# Process responses as they arrive timeout=aiohttp.ClientTimeout(
for i, response in enumerate(asyncio.as_completed(tasks)): total=3600
response = await response ) # Add timeout for request reliability
# Check if this is the prefill or decode response based on order created ) as session:
if i == 0: # First completed task try:
if str(response.url).startswith(prefill_server): # Create the tasks for both prefill and decode requests
prefill_response = response tasks = [
if response.status != 200: session.post(
raise HTTPException( f"{prefill_server}/generate", json=modified_request
status_code=response.status, ),
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}", session.post(
) f"{decode_server}/generate", json=modified_request
else: ),
decode_response = response ]
if response.status != 200: # Wait for both responses to complete. Since this is streaming, they return immediately.
raise HTTPException( prefill_response, decode_response = await asyncio.gather(*tasks)
status_code=response.status, async for chunk in decode_response.content:
detail=f"Decode server error: Status {response.status} Details: {await response.text()}", yield chunk
) except Exception as e:
else: # Second completed task error_msg = {
if str(response.url).startswith(prefill_server): "error": {"message": f"Stream processing error: {str(e)}"}
prefill_response = response }
else: yield b"data: " + orjson.dumps(
decode_response = response error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
if response.status != 200: finally:
raise HTTPException( if prefill_response is not None:
status_code=response.status, await prefill_response.release()
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
) return StreamingResponse(
stream_results(),
return await decode_response.json() media_type="text/event-stream",
)
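For orientation, here is a minimal sketch of driving the two refactored methods directly. It is not part of the commit: the MiniLoadBalancer constructor sits outside the visible hunks, so the sketch assumes an already-constructed instance, and the payload fields simply mirror the bootstrap_host / bootstrap_room keys that the old generate_request built and that the handler (second hunk below) now supplies.

import asyncio
import random

async def demo(lb):
    # Pick one prefill and one decode server at random (from the diff above).
    prefill_server, decode_server = lb.select_pair()

    # Mirrors the request mutation that moved out of generate_request and
    # into the HTTP handler; field values here are illustrative.
    modified_request = {
        "text": "Hello",
        "bootstrap_host": "127.0.0.1",  # assumed prefill server hostname
        "bootstrap_room": random.randint(0, 2**63 - 1),
    }

    # Non-streaming path: both servers receive the request concurrently and
    # the decode server's JSON body comes back wrapped in an ORJSONResponse.
    return await lb.generate(modified_request, prefill_server, decode_server)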
 app = FastAPI()

@@ -169,81 +164,14 @@ async def handle_generate_request(request_data: dict):
         }
     )

-    # Check if streaming is requested
     if request_data.get("stream", False):
-
-        async def stream_results():
-            async with aiohttp.ClientSession(
-                timeout=aiohttp.ClientTimeout(total=3600)
-            ) as session:
-                try:
-                    # Create the tasks
-                    tasks = [
-                        session.post(
-                            f"{prefill_server}/generate", json=modified_request
-                        ),
-                        session.post(
-                            f"{decode_server}/generate", json=modified_request
-                        ),
-                    ]
-                    prefill_response = None
-                    decode_response = None
-
-                    # Process responses as they arrive
-                    for i, response_task in enumerate(asyncio.as_completed(tasks)):
-                        response = await response_task
-
-                        # Check the response immediately
-                        if str(response.url).startswith(prefill_server):
-                            prefill_response = response
-                            if response.status != 200:
-                                error_msg = {
-                                    "error": {
-                                        "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
-                                    }
-                                }
-                                yield b"data: " + orjson.dumps(
-                                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                                ) + b"\n\n"
-                                return
-                        else:
-                            decode_response = response
-                            if response.status != 200:
-                                error_msg = {
-                                    "error": {
-                                        "message": f"Decode server error: Status {response.status}"
-                                    }
-                                }
-                                yield b"data: " + orjson.dumps(
-                                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                                ) + b"\n\n"
-                                return
-
-                    # Stream successful decode server response
-                    async for line in decode_response.content:
-                        yield line
-                    yield b"data: [DONE]\n\n"
-
-                except Exception as e:
-                    error_msg = {
-                        "error": {"message": f"Stream processing error: {str(e)}"}
-                    }
-                    yield b"data: " + orjson.dumps(
-                        error_msg, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-                finally:
-                    if prefill_response is not None:
-                        await prefill_response.release()
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-        )
-
-    # Non-streaming case
-    result = await load_balancer.generate_request(request_data)
-    return ORJSONResponse(content=result)
+        return await load_balancer.generate_stream(
+            modified_request, prefill_server, decode_server
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request, prefill_server, decode_server
+        )


 @app.get("/v1/models")
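And a sketch of exercising the balancer over HTTP, assuming the handler above is mounted at POST /generate (its route decorator is outside the visible hunk) and the balancer listens on port 8000; address, port, and payload fields are all illustrative.

import asyncio
import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        # Non-streaming: the balancer fans the request out to one prefill and
        # one decode server and relays the decode server's JSON reply.
        async with session.post(
            "http://127.0.0.1:8000/generate",  # assumed balancer address/route
            json={"text": "Hello", "stream": False},
        ) as resp:
            print(await resp.json())

        # Streaming: raw SSE chunks from the decode server are relayed as-is.
        async with session.post(
            "http://127.0.0.1:8000/generate",
            json={"text": "Hello", "stream": True},
        ) as resp:
            async for chunk in resp.content:
                print(chunk.decode(errors="replace"), end="")

asyncio.run(main())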