Unverified commit ae2e93f8, authored by Sumanth R Hegde, committed by GitHub
Browse files

[Fix] Fix `logprobs=0` handling for `/inference/v1/generate` endpoint (#34010)


Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
parent 9e9acce5
......@@ -87,6 +87,32 @@ async def test_generate_endpoint(client):
assert "choices" in data
@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_value", [0, 1, 5])
async def test_generate_logprobs(client, logprobs_value):
    """Check that the generate endpoint returns logprobs for 0, 1 and 5.

    A request with ``logprobs=0`` must still produce a logprobs payload;
    in that case each generated token is expected to carry exactly one
    ``top_logprobs`` entry, the same count as for ``logprobs=1``.
    """
    sampling_params = {
        "max_tokens": 5,
        "temperature": 0.0,
        "logprobs": logprobs_value,
    }
    request_body = {
        "model": MODEL_NAME,
        "token_ids": [1, 2, 3],
        "sampling_params": sampling_params,
        "stream": False,
    }

    response = await client.post(GEN_ENDPOINT, json=request_body)
    response.raise_for_status()

    first_choice = response.json()["choices"][0]
    assert first_choice["logprobs"] is not None

    # One logprobs entry is expected per generated token.
    per_token_entries = first_choice["logprobs"]["content"]
    assert len(per_token_entries) == len(first_choice["token_ids"])

    # logprobs=0 is treated like 1 when counting top_logprobs entries.
    expected_top_count = max(logprobs_value, 1)
    for token_entry in per_token_entries:
        assert "logprob" in token_entry
        top_count = len(token_entry["top_logprobs"])
        assert top_count >= 1
        assert top_count == expected_top_count
@pytest.mark.asyncio
async def test_same_response_as_chat_completions(client, tokenizer, messages):
token_ids = tokenizer.apply_chat_template(
......
......@@ -184,7 +184,7 @@ class ServingTokens(OpenAIServing):
out_logprobs = output.logprobs
# This is top_logprobs in completions API
if sampling_params.logprobs:
if sampling_params.logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_tokens_logprobs(
token_ids=token_ids,
......@@ -284,7 +284,8 @@ class ServingTokens(OpenAIServing):
logprob=max(p[1].logprob, -9999.0),
)
for i, p in enumerate(step_top_logprobs.items())
if num_output_top_logprobs and i < num_output_top_logprobs
if num_output_top_logprobs is not None
and i < max(num_output_top_logprobs, 1)
],
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment