Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
accac829
Unverified
Commit
accac829
authored
Jul 23, 2025
by
Lu Fang
Committed by
GitHub
Jul 23, 2025
Browse files
[Sampler] Introduce logprobs mode for logging (#21398)
Signed-off-by:
Lu Fang
<
lufang@fb.com
>
parent
23637dcd
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
83 additions
and
13 deletions
+83
-13
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+43
-0
vllm/config.py
vllm/config.py
+9
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+11
-7
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+15
-2
vllm/v1/sample/tpu/sampler.py
vllm/v1/sample/tpu/sampler.py
+1
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+2
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
No files found.
tests/v1/sample/test_logprobs.py
View file @
accac829
...
...
@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
assert_incr_detok_str_matches_non_incr_detok_str
,
compute_correct_cumulative_logprob
,
get_test_batch
)
from
vllm
import
SamplingParams
from
vllm.config
import
LogprobsMode
from
...conftest
import
HfRunner
,
VllmRunner
...
...
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
# prompt token
assert
prompt_logprobs
is
not
None
assert
len
(
prompt_token_ids
)
==
len
(
prompt_logprobs
)
@pytest.mark.parametrize(
    "logprobs_mode",
    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
def test_logprobs_mode(logprobs_mode: LogprobsMode,
                       monkeypatch: pytest.MonkeyPatch):
    """Test the LLM engine with each supported ``logprobs_mode``.

    For the logprobs modes, every returned value must be non-positive.
    For the logits modes, at least one returned value is expected to be
    positive.

    Args:
        logprobs_mode: The mode under test; forwarded to the ``LLM``
            constructor.
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1`` for the
            duration of the test only.
    """
    from vllm import LLM
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            "facebook/opt-125m",
            max_logprobs=5,
            enable_prefix_caching=False,
            # 2 other llms alive during whole session
            gpu_memory_utilization=0.05,
            max_model_len=16,
            logprobs_mode=logprobs_mode)
        try:
            vllm_sampling_params = SamplingParams(logprobs=1)
            results = llm.generate(["Hello world"],
                                   sampling_params=vllm_sampling_params)

            total_token_with_logprobs = 0
            positive_values = 0
            for output in results[0].outputs:
                for logprobs in output.logprobs:
                    for token_id in logprobs:
                        logprob = logprobs[token_id]
                        if "logprobs" in logprobs_mode:
                            assert logprob.logprob <= 0
                        if logprob.logprob > 0:
                            positive_values += 1
                        total_token_with_logprobs += 1
            # Every completion must have contributed at least one entry.
            assert total_token_with_logprobs >= len(results[0].outputs)
            if "logits" in logprobs_mode:
                assert positive_values > 0
        finally:
            # Drop the engine reference even when an assertion fails: the
            # failing frame is kept alive with the traceback, and a lingering
            # local would keep the engine alive alongside the other LLMs that
            # share this session.
            del llm
vllm/config.py
View file @
accac829
...
...
@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode
=
Literal
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
LogprobsMode
=
Literal
[
"raw_logprobs"
,
"raw_logits"
,
"processed_logprobs"
,
"processed_logits"
]
@
config
...
...
@@ -316,6 +318,13 @@ class ModelConfig:
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes from the default for
the OpenAI Chat Completions API."""
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported modes:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
"""
disable_sliding_window
:
bool
=
False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
...
...
vllm/engine/arg_utils.py
View file @
accac829
...
...
@@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
GuidedDecodingBackendV1
,
HfOverrides
,
KVEventsConfig
,
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
Model
Config
,
Model
DType
,
Model
Impl
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
Pooler
Config
,
P
refixCachingHashAlgo
,
PromptAdapterConfig
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
LogprobsMode
,
LoRA
Config
,
Model
Config
,
Model
DType
,
ModelImpl
,
MultiModalConfig
,
Observability
Config
,
P
arallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
PromptAdapterConfig
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
...
...
@@ -324,6 +324,7 @@ class EngineArgs:
SchedulerConfig
.
long_prefill_token_threshold
max_num_seqs
:
Optional
[
int
]
=
SchedulerConfig
.
max_num_seqs
max_logprobs
:
int
=
ModelConfig
.
max_logprobs
logprobs_mode
:
LogprobsMode
=
ModelConfig
.
logprobs_mode
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
ModelConfig
.
revision
code_revision
:
Optional
[
str
]
=
ModelConfig
.
code_revision
...
...
@@ -490,6 +491,8 @@ class EngineArgs:
**
model_kwargs
[
"max_seq_len_to_capture"
])
model_group
.
add_argument
(
"--max-logprobs"
,
**
model_kwargs
[
"max_logprobs"
])
model_group
.
add_argument
(
"--logprobs-mode"
,
**
model_kwargs
[
"logprobs_mode"
])
model_group
.
add_argument
(
"--disable-sliding-window"
,
**
model_kwargs
[
"disable_sliding_window"
])
model_group
.
add_argument
(
"--disable-cascade-attn"
,
...
...
@@ -892,6 +895,7 @@ class EngineArgs:
enforce_eager
=
self
.
enforce_eager
,
max_seq_len_to_capture
=
self
.
max_seq_len_to_capture
,
max_logprobs
=
self
.
max_logprobs
,
logprobs_mode
=
self
.
logprobs_mode
,
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_cascade_attn
=
self
.
disable_cascade_attn
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
...
...
vllm/v1/sample/sampler.py
View file @
accac829
...
...
@@ -5,6 +5,7 @@
import
torch
import
torch.nn
as
nn
from
vllm.config
import
LogprobsMode
from
vllm.utils
import
is_pin_memory_available
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
):
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
logprobs_mode
=
logprobs_mode
def
forward
(
self
,
...
...
@@ -36,7 +38,10 @@ class Sampler(nn.Module):
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs
=
sampling_metadata
.
max_num_logprobs
if
num_logprobs
is
not
None
:
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
if
self
.
logprobs_mode
==
"raw_logprobs"
:
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
elif
self
.
logprobs_mode
==
"raw_logits"
:
raw_logprobs
=
logits
.
clone
()
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
...
...
@@ -51,6 +56,14 @@ class Sampler(nn.Module):
# Apply penalties (e.g., min_tokens, freq_penalties).
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Get the processed logprobs or logits.
if
num_logprobs
is
not
None
:
if
self
.
logprobs_mode
==
"processed_logprobs"
:
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
elif
self
.
logprobs_mode
==
"processed_logits"
:
raw_logprobs
=
logits
.
clone
()
# Sample the next token.
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
# Convert sampled token ids to int64 (long) type to ensure compatibility
...
...
vllm/v1/sample/tpu/sampler.py
View file @
accac829
...
...
@@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
# TODO(houseroad): Add support for logprobs_mode.
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
accac829
...
...
@@ -389,7 +389,7 @@ class InputBatch:
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
...
...
@@ -590,7 +590,7 @@ class InputBatch:
def
refresh_metadata
(
self
):
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
* If batch state is modified, update sampling metadata
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
accac829
...
...
@@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
encoder_cache_size
=
encoder_cache_size
# Sampler
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
self
.
eplb_state
:
Optional
[
EplbState
]
=
None
"""
...
...
@@ -1996,7 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
- during DP rank dummy run
"""
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
randomize_inputs
=
envs
.
VLLM_RANDOMIZE_DP_DUMMY_INPUTS
and
dp_size
>
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment