"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "c11013db8b76bebaaed07d4791f693998e398925"
Unverified Commit bec060fd authored by Andrew Sansom's avatar Andrew Sansom Committed by GitHub
Browse files

Mark prompt logprobs as incompatible with prompt embeds at API level (#25077)


Signed-off-by: default avatarAndrew Sansom <andrew@protopia.ai>
parent 52bc9d5b
...@@ -228,3 +228,20 @@ async def test_completions_with_logprobs_and_prompt_embeds( ...@@ -228,3 +228,20 @@ async def test_completions_with_logprobs_and_prompt_embeds(
assert max(logprobs_arg, assert max(logprobs_arg,
1) <= len(top_logprobs) <= logprobs_arg + 1 1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) == 5 assert len(logprobs.tokens) == 5
@pytest.mark.asyncio
async def test_prompt_logprobs_raises_error(
client_with_prompt_embeds: openai.AsyncOpenAI):
with pytest.raises(BadRequestError, match="not compatible"):
encoded_embeds = create_dummy_embeds()
await client_with_prompt_embeds.completions.create(
model=MODEL_NAME,
prompt="",
max_tokens=5,
temperature=0.0,
extra_body={
"prompt_embeds": encoded_embeds,
"prompt_logprobs": True
},
)
...@@ -671,10 +671,13 @@ class LLMEngine: ...@@ -671,10 +671,13 @@ class LLMEngine:
arrival_time = time.time() arrival_time = time.time()
if (isinstance(prompt, dict) if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None and prompt.get("prompt_embeds", None) is not None):
and not prompt.get("prompt_token_ids", None)): if not prompt.get("prompt_token_ids", None):
seq_len = prompt["prompt_embeds"].shape[0] seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len prompt["prompt_token_ids"] = [0] * seq_len
if params.prompt_logprobs is not None:
raise ValueError(
"prompt_logprobs is not compatible with prompt embeds.")
processed_inputs = self.input_preprocessor.preprocess( processed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
......
...@@ -112,6 +112,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -112,6 +112,11 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"Echo is unsupported with prompt embeds.") "Echo is unsupported with prompt embeds.")
if (request.prompt_logprobs is not None
and request.prompt_embeds is not None):
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds.")
request_id = ( request_id = (
f"cmpl-" f"cmpl-"
f"{self._base_request_id(raw_request, request.request_id)}") f"{self._base_request_id(raw_request, request.request_id)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment