Unverified commit 6d3b35fa, authored by Byron Hsu, committed by GitHub
Browse files

[PD] Simplify mini LB (#4911)


Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
parent a73c4df4
""" """
Minimal HTTP load balancer for prefill and decode servers for testing purpose. Minimal HTTP load balancer for prefill and decode servers for testing.
""" """
import asyncio import asyncio
...@@ -22,64 +22,59 @@ class MiniLoadBalancer: ...@@ -22,64 +22,59 @@ class MiniLoadBalancer:
def select_pair(self): def select_pair(self):
return random.choice(self.prefill_servers), random.choice(self.decode_servers) return random.choice(self.prefill_servers), random.choice(self.decode_servers)
async def generate_request(self, request_data): async def generate(
prefill_server, decode_server = self.select_pair() self, modified_request, prefill_server, decode_server
) -> ORJSONResponse:
# Parse and transform prefill_server
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
bootstrap_host = f"{hostname}"
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": bootstrap_host,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = [ tasks = [
session.post(f"{prefill_server}/generate", json=modified_request), session.post(f"{prefill_server}/generate", json=modified_request),
session.post(f"{decode_server}/generate", json=modified_request), session.post(f"{decode_server}/generate", json=modified_request),
] ]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
prefill_response = None return ORJSONResponse(
decode_response = None content=await decode_response.json(),
status_code=decode_response.status,
# Process responses as they arrive
for i, response in enumerate(asyncio.as_completed(tasks)):
response = await response
# Check if this is the prefill or decode response based on order created
if i == 0: # First completed task
if str(response.url).startswith(prefill_server):
prefill_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
) )
else:
decode_response = response
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
)
else: # Second completed task
if str(response.url).startswith(prefill_server):
prefill_response = response
else:
decode_response = response
if response.status != 200: async def generate_stream(self, modified_request, prefill_server, decode_server):
raise HTTPException( async def stream_results():
status_code=response.status, async with aiohttp.ClientSession(
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}", timeout=aiohttp.ClientTimeout(
) total=3600
) # Add timeout for request reliability
) as session:
try:
# Create the tasks for both prefill and decode requests
tasks = [
session.post(
f"{prefill_server}/generate", json=modified_request
),
session.post(
f"{decode_server}/generate", json=modified_request
),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
async for chunk in decode_response.content:
yield chunk
except Exception as e:
error_msg = {
"error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
finally:
if prefill_response is not None:
await prefill_response.release()
return await decode_response.json() return StreamingResponse(
stream_results(),
media_type="text/event-stream",
)
app = FastAPI() app = FastAPI()
...@@ -169,82 +164,15 @@ async def handle_generate_request(request_data: dict): ...@@ -169,82 +164,15 @@ async def handle_generate_request(request_data: dict):
} }
) )
# Check if streaming is requested
if request_data.get("stream", False): if request_data.get("stream", False):
return await load_balancer.generate_stream(
async def stream_results(): modified_request, prefill_server, decode_server
async with aiohttp.ClientSession( )
timeout=aiohttp.ClientTimeout(total=3600)
) as session:
try:
# Create the tasks
tasks = [
session.post(
f"{prefill_server}/generate", json=modified_request
),
session.post(
f"{decode_server}/generate", json=modified_request
),
]
prefill_response = None
decode_response = None
# Process responses as they arrive
for i, response_task in enumerate(asyncio.as_completed(tasks)):
response = await response_task
# Check the response immediately
if str(response.url).startswith(prefill_server):
prefill_response = response
if response.status != 200:
error_msg = {
"error": {
"message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
else: else:
decode_response = response return await load_balancer.generate(
if response.status != 200: modified_request, prefill_server, decode_server
error_msg = {
"error": {
"message": f"Decode server error: Status {response.status}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
# Stream successful decode server response
async for line in decode_response.content:
yield line
yield b"data: [DONE]\n\n"
except Exception as e:
error_msg = {
"error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
finally:
if prefill_response is not None:
await prefill_response.release()
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
) )
# Non-streaming case
result = await load_balancer.generate_request(request_data)
return ORJSONResponse(content=result)
@app.get("/v1/models") @app.get("/v1/models")
async def get_models(): async def get_models():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment