Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
accac829
Unverified
Commit
accac829
authored
Jul 23, 2025
by
Lu Fang
Committed by
GitHub
Jul 23, 2025
Browse files
[Sampler] Introduce logprobs mode for logging (#21398)
Signed-off-by:
Lu Fang
<
lufang@fb.com
>
parent
23637dcd
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
83 additions
and
13 deletions
+83
-13
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+43
-0
vllm/config.py
vllm/config.py
+9
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+11
-7
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+15
-2
vllm/v1/sample/tpu/sampler.py
vllm/v1/sample/tpu/sampler.py
+1
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+2
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
No files found.
tests/v1/sample/test_logprobs.py
View file @
accac829
...
@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
...
@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
assert_incr_detok_str_matches_non_incr_detok_str
,
assert_incr_detok_str_matches_non_incr_detok_str
,
compute_correct_cumulative_logprob
,
get_test_batch
)
compute_correct_cumulative_logprob
,
get_test_batch
)
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
LogprobsMode
from
...conftest
import
HfRunner
,
VllmRunner
from
...conftest
import
HfRunner
,
VllmRunner
...
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
...
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
# prompt token
# prompt token
assert
prompt_logprobs
is
not
None
assert
prompt_logprobs
is
not
None
assert
len
(
prompt_token_ids
)
==
len
(
prompt_logprobs
)
assert
len
(
prompt_token_ids
)
==
len
(
prompt_logprobs
)
@
pytest
.
mark
.
parametrize
(
"logprobs_mode"
,
[
"raw_logprobs"
,
"raw_logits"
,
"processed_logprobs"
,
"processed_logits"
])
def
test_logprobs_mode
(
logprobs_mode
:
LogprobsMode
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test with LLM engine with different logprobs_mode.
For logprobs, we should have non-positive values.
For logits, we should expect at least one positive values.
"""
from
vllm
import
LLM
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
"facebook/opt-125m"
,
max_logprobs
=
5
,
enable_prefix_caching
=
False
,
# 2 other llms alive during whole session
gpu_memory_utilization
=
0.05
,
max_model_len
=
16
,
logprobs_mode
=
logprobs_mode
)
vllm_sampling_params
=
SamplingParams
(
logprobs
=
1
)
results
=
llm
.
generate
([
"Hello world"
],
sampling_params
=
vllm_sampling_params
)
total_token_with_logprobs
=
0
positive_values
=
0
for
output
in
results
[
0
].
outputs
:
for
logprobs
in
output
.
logprobs
:
for
token_id
in
logprobs
:
logprob
=
logprobs
[
token_id
]
if
"logprobs"
in
logprobs_mode
:
assert
logprob
.
logprob
<=
0
if
logprob
.
logprob
>
0
:
positive_values
=
positive_values
+
1
total_token_with_logprobs
=
total_token_with_logprobs
+
1
assert
total_token_with_logprobs
>=
len
(
results
[
0
].
outputs
)
if
"logits"
in
logprobs_mode
:
assert
positive_values
>
0
del
llm
vllm/config.py
View file @
accac829
...
@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
...
@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode
=
Literal
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]
TokenizerMode
=
Literal
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
LogprobsMode
=
Literal
[
"raw_logprobs"
,
"raw_logits"
,
"processed_logprobs"
,
"processed_logits"
]
@
config
@
config
...
@@ -316,6 +318,13 @@ class ModelConfig:
...
@@ -316,6 +318,13 @@ class ModelConfig:
"""Maximum number of log probabilities to return when `logprobs` is
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes the default for the
specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API."""
OpenAI Chat Completions API."""
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
"""
disable_sliding_window
:
bool
=
False
disable_sliding_window
:
bool
=
False
"""Whether to disable sliding window. If True, we will disable the sliding
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
window functionality of the model, capping to sliding window size. If the
...
...
vllm/engine/arg_utils.py
View file @
accac829
...
@@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
...
@@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
GuidedDecodingBackendV1
,
HfOverrides
,
KVEventsConfig
,
GuidedDecodingBackendV1
,
HfOverrides
,
KVEventsConfig
,
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
Model
Config
,
Model
DType
,
Model
Impl
,
MultiModalConfig
,
LogprobsMode
,
LoRA
Config
,
Model
Config
,
Model
DType
,
ObservabilityConfig
,
ParallelConfig
,
Pooler
Config
,
ModelImpl
,
MultiModalConfig
,
Observability
Config
,
P
refixCachingHashAlgo
,
PromptAdapterConfig
,
P
arallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SchedulerPolicy
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
get_field
)
VllmConfig
,
get_attr_docs
,
get_field
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
from
vllm.plugins
import
load_general_plugins
...
@@ -324,6 +324,7 @@ class EngineArgs:
...
@@ -324,6 +324,7 @@ class EngineArgs:
SchedulerConfig
.
long_prefill_token_threshold
SchedulerConfig
.
long_prefill_token_threshold
max_num_seqs
:
Optional
[
int
]
=
SchedulerConfig
.
max_num_seqs
max_num_seqs
:
Optional
[
int
]
=
SchedulerConfig
.
max_num_seqs
max_logprobs
:
int
=
ModelConfig
.
max_logprobs
max_logprobs
:
int
=
ModelConfig
.
max_logprobs
logprobs_mode
:
LogprobsMode
=
ModelConfig
.
logprobs_mode
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
ModelConfig
.
revision
revision
:
Optional
[
str
]
=
ModelConfig
.
revision
code_revision
:
Optional
[
str
]
=
ModelConfig
.
code_revision
code_revision
:
Optional
[
str
]
=
ModelConfig
.
code_revision
...
@@ -490,6 +491,8 @@ class EngineArgs:
...
@@ -490,6 +491,8 @@ class EngineArgs:
**
model_kwargs
[
"max_seq_len_to_capture"
])
**
model_kwargs
[
"max_seq_len_to_capture"
])
model_group
.
add_argument
(
"--max-logprobs"
,
model_group
.
add_argument
(
"--max-logprobs"
,
**
model_kwargs
[
"max_logprobs"
])
**
model_kwargs
[
"max_logprobs"
])
model_group
.
add_argument
(
"--logprobs-mode"
,
**
model_kwargs
[
"logprobs_mode"
])
model_group
.
add_argument
(
"--disable-sliding-window"
,
model_group
.
add_argument
(
"--disable-sliding-window"
,
**
model_kwargs
[
"disable_sliding_window"
])
**
model_kwargs
[
"disable_sliding_window"
])
model_group
.
add_argument
(
"--disable-cascade-attn"
,
model_group
.
add_argument
(
"--disable-cascade-attn"
,
...
@@ -892,6 +895,7 @@ class EngineArgs:
...
@@ -892,6 +895,7 @@ class EngineArgs:
enforce_eager
=
self
.
enforce_eager
,
enforce_eager
=
self
.
enforce_eager
,
max_seq_len_to_capture
=
self
.
max_seq_len_to_capture
,
max_seq_len_to_capture
=
self
.
max_seq_len_to_capture
,
max_logprobs
=
self
.
max_logprobs
,
max_logprobs
=
self
.
max_logprobs
,
logprobs_mode
=
self
.
logprobs_mode
,
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_cascade_attn
=
self
.
disable_cascade_attn
,
disable_cascade_attn
=
self
.
disable_cascade_attn
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
...
...
vllm/v1/sample/sampler.py
View file @
accac829
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
LogprobsMode
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
@@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5
...
@@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
):
super
().
__init__
()
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
logprobs_mode
=
logprobs_mode
def
forward
(
def
forward
(
self
,
self
,
...
@@ -36,7 +38,10 @@ class Sampler(nn.Module):
...
@@ -36,7 +38,10 @@ class Sampler(nn.Module):
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs
=
sampling_metadata
.
max_num_logprobs
num_logprobs
=
sampling_metadata
.
max_num_logprobs
if
num_logprobs
is
not
None
:
if
num_logprobs
is
not
None
:
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
if
self
.
logprobs_mode
==
"raw_logprobs"
:
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
elif
self
.
logprobs_mode
==
"raw_logits"
:
raw_logprobs
=
logits
.
clone
()
# Use float32 for the logits.
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
logits
=
logits
.
to
(
torch
.
float32
)
...
@@ -51,6 +56,14 @@ class Sampler(nn.Module):
...
@@ -51,6 +56,14 @@ class Sampler(nn.Module):
# Apply penalties (e.g., min_tokens, freq_penalties).
# Apply penalties (e.g., min_tokens, freq_penalties).
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Get the process logprobs or logits.
if
num_logprobs
is
not
None
:
if
self
.
logprobs_mode
==
"processed_logprobs"
:
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
elif
self
.
logprobs_mode
==
"processed_logits"
:
raw_logprobs
=
logits
.
clone
()
# Sample the next token.
# Sample the next token.
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
# Convert sampled token ids to int64 (long) type to ensure compatibility
# Convert sampled token ids to int64 (long) type to ensure compatibility
...
...
vllm/v1/sample/tpu/sampler.py
View file @
accac829
...
@@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
...
@@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
# TODO(houseroad): Add support for logprobs_mode.
super
().
__init__
()
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
accac829
...
@@ -389,7 +389,7 @@ class InputBatch:
...
@@ -389,7 +389,7 @@ class InputBatch:
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
"""This method must always be followed by a call to condense().
"""This method must always be followed by a call to condense().
Args:
Args:
req_id: request to remove
req_id: request to remove
...
@@ -590,7 +590,7 @@ class InputBatch:
...
@@ -590,7 +590,7 @@ class InputBatch:
def
refresh_metadata
(
self
):
def
refresh_metadata
(
self
):
"""Apply batch updates, reset input batch at end of step
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
* Apply batch add/remove/permute to logits procs' states
* If batch state is modified, update sampling metadata
* If batch state is modified, update sampling metadata
"""
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
accac829
...
@@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
encoder_cache_size
=
encoder_cache_size
self
.
encoder_cache_size
=
encoder_cache_size
# Sampler
# Sampler
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
self
.
eplb_state
:
Optional
[
EplbState
]
=
None
self
.
eplb_state
:
Optional
[
EplbState
]
=
None
"""
"""
...
@@ -1996,7 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1996,7 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
This is to help balance expert-selection
- during profile_run
- during profile_run
- during DP rank dummy run
- during DP rank dummy run
"""
"""
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
randomize_inputs
=
envs
.
VLLM_RANDOMIZE_DP_DUMMY_INPUTS
and
dp_size
>
1
randomize_inputs
=
envs
.
VLLM_RANDOMIZE_DP_DUMMY_INPUTS
and
dp_size
>
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment