"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "82f60cc781afcfe735028ae72aeb04f90b687b78"
Unverified Commit 54de71d0 authored by 22quinn's avatar 22quinn Committed by GitHub
Browse files

[Sampler] Support returning all logprobs or logits (#21792)


Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent fed5849d
...@@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts, ...@@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
assert len(prompt_token_ids) == len(prompt_logprobs) assert len(prompt_token_ids) == len(prompt_logprobs)
def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
"""Engine should return all vocabulary logprobs
Args:
example_prompts: list of example prompts (test fixture)
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
runner = VllmRunner(
"facebook/opt-125m",
max_logprobs=-1,
enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15,
max_model_len=256)
sampling_params_logprobs_all = SamplingParams(max_tokens=5,
logprobs=-1)
results_logprobs_all = runner.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_all)
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
for i in range(len(results_logprobs_all)):
logprobs = results_logprobs_all[i].outputs[0].logprobs
assert logprobs is not None
for logprob in logprobs:
assert len(logprob) == vocab_size
@pytest.mark.parametrize( @pytest.mark.parametrize(
"logprobs_mode", "logprobs_mode",
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
......
...@@ -377,7 +377,8 @@ class ModelConfig: ...@@ -377,7 +377,8 @@ class ModelConfig:
max_logprobs: int = 20 max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is """Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes the default for the specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API.""" OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode: LogprobsMode = "raw_logprobs" logprobs_mode: LogprobsMode = "raw_logprobs"
"""Indicates the content returned in the logprobs and prompt_logprobs. """Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode: Supported mode:
...@@ -1585,7 +1586,7 @@ class ModelConfig: ...@@ -1585,7 +1586,7 @@ class ModelConfig:
""" """
This method attempts to retrieve the non-default values of the This method attempts to retrieve the non-default values of the
generation config for this model. generation config for this model.
The generation config can contain information about special tokens, as The generation config can contain information about special tokens, as
well as sampling parameters. Which is why this method exists separately well as sampling parameters. Which is why this method exists separately
to `get_diff_sampling_param`. to `get_diff_sampling_param`.
...@@ -2066,7 +2067,7 @@ class ParallelConfig: ...@@ -2066,7 +2067,7 @@ class ParallelConfig:
and when data_parallel_size > 0. Enables running an AsyncLLM and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank.""" --data-parallel-start-rank."""
enable_expert_parallel: bool = False enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers.""" """Use expert parallelism instead of tensor parallelism for MoE layers."""
......
...@@ -156,6 +156,7 @@ class SamplingParams( ...@@ -156,6 +156,7 @@ class SamplingParams(
Note that the implementation follows the OpenAI API: The API will Note that the implementation follows the OpenAI API: The API will
always return the log probability of the sampled token, so there always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response. may be up to `logprobs+1` elements in the response.
When set to -1, return all `vocab_size` log probabilities.
prompt_logprobs: Number of log probabilities to return per prompt token. prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True. detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output. skip_special_tokens: Whether to skip special tokens in the output.
...@@ -414,9 +415,10 @@ class SamplingParams( ...@@ -414,9 +415,10 @@ class SamplingParams(
raise ValueError( raise ValueError(
f"min_tokens must be less than or equal to " f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.") f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
if self.logprobs is not None and self.logprobs < 0: if (self.logprobs is not None and self.logprobs != -1
and self.logprobs < 0):
raise ValueError( raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.") f"logprobs must be non-negative or -1, got {self.logprobs}.")
if self.prompt_logprobs is not None and self.prompt_logprobs < 0: if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got " raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.") f"{self.prompt_logprobs}.")
......
...@@ -138,7 +138,7 @@ class LogprobsProcessor: ...@@ -138,7 +138,7 @@ class LogprobsProcessor:
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
"""Pop and return all request prompt logprobs """Pop and return all request prompt logprobs
The logprobs processor aggregates prompt chunk logprobs The logprobs processor aggregates prompt chunk logprobs
over one or more prefill chunks. This method returns over one or more prefill chunks. This method returns
all prompt logprobs at once and then forgets them. all prompt logprobs at once and then forgets them.
...@@ -176,7 +176,8 @@ class LogprobsProcessor: ...@@ -176,7 +176,8 @@ class LogprobsProcessor:
Returns: Returns:
dict[token id, Logprob] dict[token id, Logprob]
""" """
if num_logprobs == -1:
num_logprobs = len(logprobs)
# We do not need a special case for the sampled token # We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data # being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once. # into a dictionary twice is the same as doing it once.
......
...@@ -65,8 +65,11 @@ class Processor: ...@@ -65,8 +65,11 @@ class Processor:
params: SamplingParams, params: SamplingParams,
) -> None: ) -> None:
max_logprobs = self.model_config.max_logprobs max_logprobs = self.model_config.max_logprobs
if max_logprobs == -1:
return
# Validate sample logprobs. # Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs: if params.logprobs and (params.logprobs == -1
or params.logprobs > max_logprobs):
raise ValueError( raise ValueError(
f"Requested sample logprobs of {params.logprobs}, " f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}") f"which is greater than max allowed: {max_logprobs}")
......
...@@ -337,7 +337,9 @@ class InputBatch: ...@@ -337,7 +337,9 @@ class InputBatch:
self.generators[req_index] = request.generator self.generators[req_index] = request.generator
if sampling_params.logprobs is not None: if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs self.num_logprobs[req_id] = (self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs)
if sampling_params.prompt_logprobs is not None: if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[ self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs req_id] = sampling_params.prompt_logprobs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment