Unverified Commit 54de71d0 authored by 22quinn's avatar 22quinn Committed by GitHub
Browse files

[Sampler] Support returning all logprobs or logits (#21792)


Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent fed5849d
......@@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
assert len(prompt_token_ids) == len(prompt_logprobs)
def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
    """Engine should return all vocabulary logprobs

    The engine is created with `max_logprobs=-1` and the request uses
    `logprobs=-1`; each entry in the first output's `logprobs` must then
    contain exactly `vocab_size` items.

    Args:
      example_prompts: list of example prompts (test fixture)
      monkeypatch: pytest fixture; `VLLM_USE_V1` is set to "1" inside a
        monkeypatch context for the duration of the test body
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        runner = VllmRunner(
            "facebook/opt-125m",
            max_logprobs=-1,
            enable_prefix_caching=False,
            # 2 other llms alive during whole session
            gpu_memory_utilization=0.15,
            max_model_len=256)
        sampling_params_logprobs_all = SamplingParams(max_tokens=5,
                                                      logprobs=-1)
        results_logprobs_all = runner.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_all)
        vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
        # Iterate the results directly; the index was only ever used to
        # look the element back up.
        for result in results_logprobs_all:
            logprobs = result.outputs[0].logprobs
            assert logprobs is not None
            for logprob in logprobs:
                assert len(logprob) == vocab_size
@pytest.mark.parametrize(
"logprobs_mode",
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
......
......@@ -377,7 +377,8 @@ class ModelConfig:
max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes from the default for the
OpenAI Chat Completions API."""
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode: LogprobsMode = "raw_logprobs"
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
......@@ -1585,7 +1586,7 @@ class ModelConfig:
"""
This method attempts to retrieve the non-default values of the
generation config for this model.
The generation config can contain information about special tokens, as
well as sampling parameters, which is why this method exists separately
from `get_diff_sampling_param`.
......@@ -2066,7 +2067,7 @@ class ParallelConfig:
and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
......
......@@ -156,6 +156,7 @@ class SamplingParams(
Note that the implementation follows the OpenAI API: The API will
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
When set to -1, return all `vocab_size` log probabilities.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
......@@ -414,9 +415,10 @@ class SamplingParams(
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
if self.logprobs is not None and self.logprobs < 0:
if (self.logprobs is not None and self.logprobs != -1
and self.logprobs < 0):
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
f"logprobs must be non-negative or -1, got {self.logprobs}.")
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
......
......@@ -138,7 +138,7 @@ class LogprobsProcessor:
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
"""Pop and return all request prompt logprobs
The logprobs processor aggregates prompt chunk logprobs
over one or more prefill chunks. This method returns
all prompt logprobs at once and then forgets them.
......@@ -176,7 +176,8 @@ class LogprobsProcessor:
Returns:
dict[token id, Logprob]
"""
if num_logprobs == -1:
num_logprobs = len(logprobs)
# We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once.
......
......@@ -65,8 +65,11 @@ class Processor:
params: SamplingParams,
) -> None:
max_logprobs = self.model_config.max_logprobs
if max_logprobs == -1:
return
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
if params.logprobs and (params.logprobs == -1
or params.logprobs > max_logprobs):
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
......
......@@ -337,7 +337,9 @@ class InputBatch:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
self.num_logprobs[req_id] = (self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment