Unverified Commit 6d3b35fa authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] Simplify mini LB (#4911)


Co-authored-by: default avatarLiangsheng Yin <hnyls2002@gmail.com>
parent a73c4df4
"""
Minimal HTTP load balancer for prefill and decode servers for testing purpose.
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
import asyncio
......@@ -22,64 +22,59 @@ class MiniLoadBalancer:
def select_pair(self):
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
async def generate_request(self, request_data):
prefill_server, decode_server = self.select_pair()
# Parse and transform prefill_server
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
bootstrap_host = f"{hostname}"
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": bootstrap_host,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
async def generate(
self, modified_request, prefill_server, decode_server
) -> ORJSONResponse:
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = [
session.post(f"{prefill_server}/generate", json=modified_request),
session.post(f"{decode_server}/generate", json=modified_request),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
return ORJSONResponse(
content=await decode_response.json(),
status_code=decode_response.status,
)
prefill_response = None
decode_response = None
# Process responses as they arrive
for i, response in enumerate(asyncio.as_completed(tasks)):
response = await response
# Check if this is the prefill or decode response based on order created
if i == 0: # First completed task
if str(response.url).startswith(prefill_server):
prefill_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
)
else:
decode_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
)
else: # Second completed task
if str(response.url).startswith(prefill_server):
prefill_response = response
else:
decode_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
)
return await decode_response.json()
async def generate_stream(self, modified_request, prefill_server, decode_server):
async def stream_results():
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=3600
) # Add timeout for request reliability
) as session:
try:
# Create the tasks for both prefill and decode requests
tasks = [
session.post(
f"{prefill_server}/generate", json=modified_request
),
session.post(
f"{decode_server}/generate", json=modified_request
),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
async for chunk in decode_response.content:
yield chunk
except Exception as e:
error_msg = {
"error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
finally:
if prefill_response is not None:
await prefill_response.release()
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
)
app = FastAPI()
......@@ -169,81 +164,14 @@ async def handle_generate_request(request_data: dict):
}
)
# Check if streaming is requested
if request_data.get("stream", False):
async def stream_results():
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=3600)
) as session:
try:
# Create the tasks
tasks = [
session.post(
f"{prefill_server}/generate", json=modified_request
),
session.post(
f"{decode_server}/generate", json=modified_request
),
]
prefill_response = None
decode_response = None
# Process responses as they arrive
for i, response_task in enumerate(asyncio.as_completed(tasks)):
response = await response_task
# Check the response immediately
if str(response.url).startswith(prefill_server):
prefill_response = response
if response.status != 200:
error_msg = {
"error": {
"message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
else:
decode_response = response
if response.status != 200:
error_msg = {
"error": {
"message": f"Decode server error: Status {response.status}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
# Stream successful decode server response
async for line in decode_response.content:
yield line
yield b"data: [DONE]\n\n"
except Exception as e:
error_msg = {
"error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
finally:
if prefill_response is not None:
await prefill_response.release()
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
return await load_balancer.generate_stream(
modified_request, prefill_server, decode_server
)
else:
return await load_balancer.generate(
modified_request, prefill_server, decode_server
)
# Non-streaming case
result = await load_balancer.generate_request(request_data)
return ORJSONResponse(content=result)
@app.get("/v1/models")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment