Unverified commit f510395b, authored by Roy, committed by GitHub
Browse files

[BugFix][Frontend] Fix completion logprobs=0 error (#3731)

parent 6110c39d
...@@ -199,6 +199,27 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, ...@@ -199,6 +199,27 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5 completion.choices[0].text) >= 5
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
                             model_name: str):
    """Check that a completion request with ``logprobs=0`` succeeds.

    The request is issued once per parametrized ``model_name``. The
    response must carry a ``logprobs`` object with ``token_logprobs``
    populated and ``top_logprobs`` left as ``None``.
    """
    # test using token IDs
    completion = await client.completions.create(
        # Use the parametrized name so every listed model is actually
        # exercised, not just the base model.
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=0,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora hereafter # just test 1 lora hereafter
"model_name", "model_name",
......
...@@ -330,7 +330,7 @@ class LogProbs(BaseModel): ...@@ -330,7 +330,7 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(BaseModel):
......
...@@ -251,9 +251,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -251,9 +251,6 @@ class OpenAIServingCompletion(OpenAIServing):
i]:] if output.logprobs else None i]:] if output.logprobs else None
if request.logprobs is not None: if request.logprobs is not None:
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs( logprobs = self._create_logprobs(
token_ids=delta_token_ids, token_ids=delta_token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
......
...@@ -534,7 +534,8 @@ def _get_logprobs( ...@@ -534,7 +534,8 @@ def _get_logprobs(
# Prepare query indices # Prepare query indices
batched_logprobs_query_seq_indices: List[int] = [] batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = [] batched_logprobs_query_token_indices: List[int] = []
largest_num_logprobs = 0 # at least get one logprob for each token
largest_num_logprobs = 1
sample_idx = 0 sample_idx = 0
for i, (seq_group, sample_result) in enumerate( for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)): zip(sampling_metadata.seq_groups, sample_results)):
...@@ -643,7 +644,7 @@ def _get_logprobs( ...@@ -643,7 +644,7 @@ def _get_logprobs(
batched_ranks_query_result[query_result_idx].item()) batched_ranks_query_result[query_result_idx].item())
} }
query_result_idx += 1 query_result_idx += 1
if num_logprobs > 0: if num_logprobs >= 0:
sample_logprobs_dict.update( sample_logprobs_dict.update(
zip( zip(
top_token_ids[sample_idx + top_token_ids[sample_idx +
......
...@@ -111,7 +111,7 @@ class RequestOutput: ...@@ -111,7 +111,7 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence # NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the # always has the logprobs of the sampled tokens even if the
# logprobs are not requested. # logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs include_logprobs = seq_group.sampling_params.logprobs is not None
outputs = [ outputs = [
CompletionOutput(seqs.index(seq), seq.output_text, CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(), seq.get_output_token_ids(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment