Unverified commit 910abdbd, authored by Chauncey and committed by GitHub
Browse files

[Bugfix] fixed top_logprobs: -1 does not appear to work as intended (#26470)


Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
parent cddce79f
......@@ -7,12 +7,23 @@ import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from vllm.config import ModelConfig
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
def get_vocab_size(model_name):
    """Return the vocabulary size that ``ModelConfig`` reports for a model.

    A ``ModelConfig`` is built for ``model_name`` with a fixed seed of 0 and
    the ``float16`` dtype, and its ``get_vocab_size()`` result is returned.
    """
    # Same keyword arguments as a direct constructor call, gathered in one
    # place so the config settings used by the test are easy to see.
    config_kwargs = {
        "model": model_name,
        "seed": 0,
        "dtype": "float16",
    }
    return ModelConfig(**config_kwargs).get_vocab_size()
@pytest.fixture(scope="module")
def server():
args = [
......@@ -107,6 +118,7 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1,
extra_body={
"top_logprobs": -1,
"logprobs": "true",
......@@ -115,3 +127,6 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
assert completion.choices[0].logprobs is not None
assert completion.choices[0].logprobs.content is not None
assert len(completion.choices[0].logprobs.content) > 0
assert len(
completion.choices[0].logprobs.content[0].top_logprobs
) == get_vocab_size(MODEL_NAME)
......@@ -1643,7 +1643,7 @@ class OpenAIServingChat(OpenAIServing):
bytes=list(token.encode("utf-8", errors="replace")),
)
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
if (top_logprobs and i < top_logprobs or top_logprobs == -1)
]
def _create_chat_logprobs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment