Unverified commit a178a0b4, authored by Nick Hill, committed by GitHub
Browse files

[BugFix] Fix duplicate id tool-call race condition (#29355)


Signed-off-by: Nick Hill <nhill@redhat.com>
parent b8328b49
...@@ -273,6 +273,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -273,6 +273,11 @@ class OpenAIServingChat(OpenAIServing):
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(request_prompts[i]) prompt_text, _, _ = self._get_prompt_components(request_prompts[i])
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
)
if self.default_sampling_params is None: if self.default_sampling_params is None:
self.default_sampling_params = {} self.default_sampling_params = {}
...@@ -301,7 +306,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -301,7 +306,7 @@ class OpenAIServingChat(OpenAIServing):
) )
self._log_inputs( self._log_inputs(
request_id, sub_request_id,
request_prompts[i], request_prompts[i],
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
...@@ -316,14 +321,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -316,14 +321,14 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_prompt,
request_id=request_id, request_id=sub_request_id,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
engine_request, tokenization_kwargs = await self._process_inputs( engine_request, tokenization_kwargs = await self._process_inputs(
request_id, sub_request_id,
engine_prompt, engine_prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
...@@ -334,7 +339,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -334,7 +339,7 @@ class OpenAIServingChat(OpenAIServing):
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_request, engine_request,
sampling_params, sampling_params,
request_id, sub_request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
......
...@@ -1242,16 +1242,19 @@ class OpenAIServing: ...@@ -1242,16 +1242,19 @@ class OpenAIServing:
): ):
prompt_text, _, _ = self._get_prompt_components(request_prompt) prompt_text, _, _ = self._get_prompt_components(request_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0
while True: while True:
# Ensure that each sub-request has a unique request id.
sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs( self._log_inputs(
request_id, sub_request_id,
request_prompt, request_prompt,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
trace_headers = kwargs.get("trace_headers") trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = await self._process_inputs( engine_request, tokenization_kwargs = await self._process_inputs(
request_id, sub_request_id,
engine_prompt, engine_prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
...@@ -1262,7 +1265,7 @@ class OpenAIServing: ...@@ -1262,7 +1265,7 @@ class OpenAIServing:
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_request, engine_request,
sampling_params, sampling_params,
request_id, sub_request_id,
lora_request=lora_request, lora_request=lora_request,
priority=priority, priority=priority,
prompt_text=prompt_text, prompt_text=prompt_text,
...@@ -1295,6 +1298,7 @@ class OpenAIServing: ...@@ -1295,6 +1298,7 @@ class OpenAIServing:
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids) sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
# OPTIMIZATION # OPTIMIZATION
priority = orig_priority - 1 priority = orig_priority - 1
sub_request += 1
def _get_prompt_components( def _get_prompt_components(
self, self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment