"vscode:/vscode.git/clone" did not exist on "b570f2c17130c30be56a276aa0d1ed11a096dad1"
Unverified commit 1dc6864f authored by shangmingc, committed by GitHub

[PD] Support completion endpoint (#6729)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 485a023b
@@ -274,8 +274,7 @@ async def handle_generate_request(request_data: dict):
     )


-@app.post("/v1/chat/completions")
-async def handle_completion_request(request_data: dict):
+async def _forward_to_backend(request_data: dict, endpoint_name: str):
     prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

     # Parse and transform prefill_server for bootstrap data
@@ -286,7 +285,7 @@ async def handle_completion_request(request_data: dict):
         {
             "bootstrap_host": hostname,
             "bootstrap_port": bootstrap_port,
-            "bootstrap_room": random.randint(0, 2**63 - 1),
+            "bootstrap_room": _generate_bootstrap_room(),
         }
     )
@@ -295,17 +294,27 @@ async def handle_completion_request(request_data: dict):
             modified_request,
             prefill_server,
             decode_server,
-            endpoint="v1/chat/completions",
+            endpoint=endpoint_name,
         )
     else:
         return await load_balancer.generate(
             modified_request,
             prefill_server,
             decode_server,
-            endpoint="v1/chat/completions",
+            endpoint=endpoint_name,
         )


+@app.post("/v1/chat/completions")
+async def handle_chat_completion_request(request_data: dict):
+    return await _forward_to_backend(request_data, "v1/chat/completions")
+
+
+@app.post("/v1/completions")
+async def handle_completion_request(request_data: dict):
+    return await _forward_to_backend(request_data, "v1/completions")
+
+
 def _generate_bootstrap_room():
     return random.randint(0, 2**63 - 1)
...
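The hunk above factors the proxy logic out of the chat-completions handler into a shared `_forward_to_backend` helper and registers `/v1/completions` on top of it. A minimal client-side sketch of the new route follows; the base URL, model name, and prompt are placeholders rather than anything taken from this commit.

```python
# Minimal sketch (not part of the commit): call the PD mini load balancer's new
# /v1/completions endpoint. Assumes the load balancer listens on localhost:8000
# with prefill/decode servers already registered; the model name is a placeholder.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",
        "prompt": "The capital of France is",
        "max_tokens": 16,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["text"])
```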
@@ -604,6 +604,9 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        bootstrap_host=all_requests[0].bootstrap_host,
+        bootstrap_port=all_requests[0].bootstrap_port,
+        bootstrap_room=all_requests[0].bootstrap_room,
     )
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
...
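For context, `v1_generate_request` now copies the three `bootstrap_*` fields from the incoming completion request into the adapted generate request. A rough illustration of the payload shape the load balancer forwards after injecting that metadata (all values invented for the example):

```python
# Rough illustration only: a completion request after the mini load balancer has
# injected PD bootstrap metadata. Values are made up; in the real flow the
# hostname/port come from load_balancer.select_pair() and the room id from
# _generate_bootstrap_room().
modified_request = {
    "model": "my-model",
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "bootstrap_host": "10.0.0.1",
    "bootstrap_port": 8998,
    "bootstrap_room": 123456789,
}
print(modified_request)
```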
@@ -183,12 +183,17 @@ class CompletionRequest(BaseModel):
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None

+    # For PD disaggregation
+    bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
+    bootstrap_room: Optional[int] = None
+

 class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Literal["stop", "length", "content_filter"]
+    finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
...
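The new fields on `CompletionRequest` are optional and default to `None`, so existing clients are unaffected; only requests rewritten by the PD load balancer carry them. The `finish_reason` literal also gains `"abort"`, so an aborted request can be reported without failing response validation. A small sketch exercising the model follows; the import path is assumed and may differ between versions.

```python
# Sketch, assuming the protocol module path below; it may differ by version.
from sglang.srt.openai_api.protocol import CompletionRequest

# Ordinary request: the PD bootstrap fields default to None.
plain = CompletionRequest(model="my-model", prompt="Hello")
assert plain.bootstrap_host is None and plain.bootstrap_room is None

# Request as the PD load balancer would forward it (values invented).
pd = CompletionRequest(
    model="my-model",
    prompt="Hello",
    bootstrap_host="10.0.0.1",
    bootstrap_port=8998,
    bootstrap_room=123456789,
)
```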