Unverified commit 3f57b00a authored by Yongtong Wu, committed by GitHub

Support PD bootstrap fields on /v1/chat/completions endpoint (#5488)

parent 453d412c
```diff
@@ -23,8 +23,9 @@ class MiniLoadBalancer:
         return random.choice(self.prefill_servers), random.choice(self.decode_servers)
 
     async def generate(
-        self, modified_request, prefill_server, decode_server
+        self, modified_request, prefill_server, decode_server, endpoint
     ) -> ORJSONResponse:
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(
@@ -32,8 +33,8 @@ class MiniLoadBalancer:
             ) # Add timeout for request reliability
         ) as session:
             tasks = [
-                session.post(f"{prefill_server}/generate", json=modified_request),
-                session.post(f"{decode_server}/generate", json=modified_request),
+                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
+                session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
             prefill_response, decode_response = await asyncio.gather(*tasks)
@@ -43,7 +44,11 @@ class MiniLoadBalancer:
             status_code=decode_response.status,
         )
 
-    async def generate_stream(self, modified_request, prefill_server, decode_server):
+    async def generate_stream(
+        self, modified_request, prefill_server, decode_server, endpoint="generate"
+    ):
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
+
         async def stream_results():
             async with aiohttp.ClientSession(
                 timeout=aiohttp.ClientTimeout(
@@ -54,10 +59,10 @@ class MiniLoadBalancer:
                 # Create the tasks for both prefill and decode requests
                 tasks = [
                     session.post(
-                        f"{prefill_server}/generate", json=modified_request
+                        f"{prefill_server}/{endpoint}", json=modified_request
                     ),
                     session.post(
-                        f"{decode_server}/generate", json=modified_request
+                        f"{decode_server}/{endpoint}", json=modified_request
                     ),
                 ]
                 # Wait for both responses to complete. Since this is streaming, they return immediately.
```
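Both methods now take the target endpoint as a parameter and join it onto each server's base URL with an f-string; an endpoint with a leading slash would produce a double slash in the final URL, which is what the new asserts guard against. A minimal sketch of the joining behavior (host and port values here are hypothetical, not from the PR):

```python
# Minimal sketch of the URL-joining convention, assuming a base URL without a
# trailing slash. The "prefill" host and port are made-up example values.
base = "http://prefill:30000"
endpoint = "v1/chat/completions"

# Same check the load balancer performs before building the URL; with a
# leading slash, f"{base}/{endpoint}" would yield ".../​/v1/chat/completions".
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

url = f"{base}/{endpoint}"
print(url)  # http://prefill:30000/v1/chat/completions
```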
```diff
@@ -190,6 +195,37 @@ async def handle_generate_request(request_data: dict):
     )
 
 
+@app.post("/v1/chat/completions")
+async def handle_completion_request(request_data: dict):
+    prefill_server, decode_server = load_balancer.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+
+    modified_request = request_data.copy()
+    modified_request.update(
+        {
+            "bootstrap_host": hostname,
+            "bootstrap_room": random.randint(0, 2**63 - 1),
+        }
+    )
+
+    if request_data.get("stream", False):
+        return await load_balancer.generate_stream(
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
+        )
+
+
 def _generate_bootstrap_room():
     return random.randint(0, 2**63 - 1)
```
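The new handler mirrors handle_generate_request: it picks a prefill/decode pair, injects bootstrap_host (the chosen prefill server's hostname) and a random bootstrap_room into the request body, then fans the request out to both servers. From the client's side nothing changes; a hypothetical call, assuming the load balancer listens on localhost:8000:

```python
# Hypothetical client call; the load balancer's address and the model name
# are assumptions. The client never sets the bootstrap fields itself.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
)
print(resp.status_code, resp.json())
```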
...
```diff
@@ -1174,6 +1174,8 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
         lora_path=lora_paths,
+        bootstrap_host=all_requests[0].bootstrap_host,
+        bootstrap_room=all_requests[0].bootstrap_room,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
```
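In the OpenAI adapter, the bootstrap fields are copied from the first request onto the batched internal request, so a multi-request batch inherits the first entry's values. A self-contained sketch of that forwarding rule, using trimmed stand-ins for the real ChatCompletionRequest and GenerateReqInput types:

```python
# Trimmed stand-ins for the real models; only the bootstrap fields are kept.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ChatCompletionRequest:
    bootstrap_host: Optional[str] = None
    bootstrap_room: Optional[int] = None


@dataclass
class GenerateReqInput:
    bootstrap_host: Optional[str] = None
    bootstrap_room: Optional[int] = None


def adapt(all_requests: List[ChatCompletionRequest]) -> GenerateReqInput:
    # Mirrors the diff: a batch inherits the first request's bootstrap fields.
    return GenerateReqInput(
        bootstrap_host=all_requests[0].bootstrap_host,
        bootstrap_room=all_requests[0].bootstrap_room,
    )


reqs = [ChatCompletionRequest(bootstrap_host="10.0.0.1", bootstrap_room=42)]
assert adapt(reqs).bootstrap_host == "10.0.0.1"
```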
......
```diff
@@ -362,6 +362,10 @@ class ChatCompletionRequest(BaseModel):
     separate_reasoning: bool = True
     stream_reasoning: bool = True
 
+    # For PD disaggregation
+    bootstrap_host: Optional[str] = None
+    bootstrap_room: Optional[int] = None
+
 
 class FunctionResponse(BaseModel):
     """Function response."""
```
...