Unverified Commit 06b2550c authored by Toshiki Kataoka's avatar Toshiki Kataoka Committed by GitHub
Browse files

[Bugfix] Support `prompt_logprobs==0` (#5217)

parent f775a07e
......@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) <= 1
assert len(choice.logprobs.top_logprobs[0]) == 1
@pytest.mark.asyncio
......@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) <= 6
assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
@pytest.mark.asyncio
......@@ -1217,8 +1217,9 @@ number: "1" | "2"
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
model_name: str):
model_name: str, logprobs_arg: int):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# test using text and token IDs
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
......@@ -1227,7 +1228,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
max_tokens=5,
temperature=0.0,
echo=True,
logprobs=1)
logprobs=logprobs_arg)
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
list) else prompt
......@@ -1240,6 +1241,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
and logprobs.token_logprobs[0] is None)
assert (len(logprobs.top_logprobs) > 5
and logprobs.top_logprobs[0] is None)
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg,
1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) > 5
......
......@@ -312,7 +312,7 @@ class OpenAIServingCompletion(OpenAIServing):
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs else None)
if request.logprobs is not None else None)
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
......
......@@ -233,7 +233,7 @@ def _prepare_seq_groups(
logits = hidden_states[selected_token_indices]
"""
if sampling_params.prompt_logprobs:
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + prompt_logprob_len))
model_output_idx += prompt_logprob_len
......
......@@ -427,7 +427,7 @@ class ModelRunner:
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
is not None else 1))
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment