Unverified Commit 475e2e37 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

[PD] Fix server crash when using batch requests (#5531)

parent fba86b6b
...@@ -161,12 +161,24 @@ async def handle_generate_request(request_data: dict): ...@@ -161,12 +161,24 @@ async def handle_generate_request(request_data: dict):
parsed_url = urllib.parse.urlparse(prefill_server) parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname hostname = parsed_url.hostname
modified_request = request_data.copy() modified_request = request_data.copy()
modified_request.update(
{ batch_size = _get_request_batch_size(modified_request)
"bootstrap_host": hostname, if batch_size is not None:
"bootstrap_room": random.randint(0, 2**63 - 1), modified_request.update(
} {
) "bootstrap_host": [hostname] * batch_size,
"bootstrap_room": [
_generate_bootstrap_room() for _ in range(batch_size)
],
}
)
else:
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False): if request_data.get("stream", False):
return await load_balancer.generate_stream( return await load_balancer.generate_stream(
...@@ -178,6 +190,19 @@ async def handle_generate_request(request_data: dict): ...@@ -178,6 +190,19 @@ async def handle_generate_request(request_data: dict):
) )
def _generate_bootstrap_room():
    """Draw a random bootstrap room id.

    Returns:
        int: a value from ``random.randint`` in the inclusive range
        ``[0, 2**63 - 1]``.
    """
    # (1 << 63) - 1 == 2**63 - 1, the inclusive upper bound of the range.
    max_room_id = (1 << 63) - 1
    return random.randint(0, max_room_id)
# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
    """Return the batch size of a generate request, or ``None`` if not batched.

    The ``"text"`` field is inspected first and ``"input_ids"`` is only
    consulted when ``"text"`` is absent or ``None``.

    Args:
        request: mapping holding the request payload; only the ``"text"``
            and ``"input_ids"`` keys are read.

    Returns:
        ``None`` when the request carries a single prompt (``text`` is a
        ``str``, or ``input_ids`` is a flat sequence of ints), when
        ``input_ids`` is empty, or when neither field is present. Otherwise
        the number of prompts in the batch (``len`` of the field).
    """
    if (text := request.get("text")) is not None:
        return None if isinstance(text, str) else len(text)

    if (input_ids := request.get("input_ids")) is not None:
        # An empty sequence has no first element to inspect; indexing it
        # would raise IndexError. Treat it as a non-batched request here
        # instead of failing in this helper.
        if not input_ids:
            return None
        # A flat list of ints is one tokenized prompt; a list of lists is
        # a batch with one entry per prompt.
        return None if isinstance(input_ids[0], int) else len(input_ids)

    return None
@app.get("/v1/models") @app.get("/v1/models")
async def get_models(): async def get_models():
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
......
...@@ -96,8 +96,8 @@ class GenerateReqInput: ...@@ -96,8 +96,8 @@ class GenerateReqInput:
return_hidden_states: bool = False return_hidden_states: bool = False
# For disaggregated inference # For disaggregated inference
bootstrap_host: Optional[str] = None bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_room: Optional[int] = None bootstrap_room: Optional[Union[List[int], int]] = None
def normalize_batch_and_arguments(self): def normalize_batch_and_arguments(self):
""" """
...@@ -397,6 +397,12 @@ class GenerateReqInput: ...@@ -397,6 +397,12 @@ class GenerateReqInput:
else None else None
), ),
return_hidden_states=self.return_hidden_states, return_hidden_states=self.return_hidden_states,
bootstrap_host=(
self.bootstrap_host[i] if self.bootstrap_host is not None else None
),
bootstrap_room=(
self.bootstrap_room[i] if self.bootstrap_room is not None else None
),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment