Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d706cca
Unverified
Commit
8d706cca
authored
Nov 10, 2025
by
Zhuohan Li
Committed by
GitHub
Nov 11, 2025
Browse files
[Misc] FlattenLogprobs -> FlatLogprobs (#28335)
parent
57201a6a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
47 deletions
+43
-47
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+6
-10
tests/test_logprobs.py
tests/test_logprobs.py
+20
-20
vllm/envs.py
vllm/envs.py
+4
-4
vllm/logprobs.py
vllm/logprobs.py
+13
-13
No files found.
tests/samplers/test_logprobs.py
View file @
8d706cca
...
...
@@ -4,7 +4,7 @@
import
pytest
from
vllm
import
SamplingParams
from
vllm.logprobs
import
Flat
ten
Logprobs
from
vllm.logprobs
import
FlatLogprobs
MODELS
=
[
"distilbert/distilgpt2"
]
MAX_TOKENS
=
5
...
...
@@ -16,17 +16,17 @@ MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"greedy"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"flat
ten
_logprobs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"flat_logprobs"
,
[
True
,
False
])
def
test_ranks
(
vllm_runner
,
model
,
dtype
,
greedy
,
flat
ten
_logprobs
,
flat_logprobs
,
example_prompts
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"1"
if
flat
ten
_logprobs
else
"0"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"1"
if
flat_logprobs
else
"0"
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
MAX_LOGPROBS
)
as
vllm_model
:
tokenizer
=
vllm_model
.
llm
.
get_tokenizer
()
example_prompt_tokens
=
[
tokenizer
.
encode
(
prompt
)
for
prompt
in
example_prompts
]
...
...
@@ -44,12 +44,8 @@ def test_ranks(
decode_tokens
,
_
,
decode_logprobs
,
prompt_logprobs
=
result
# Ensure the return type of logprobs is accurate
assert
isinstance
(
prompt_logprobs
,
FlattenLogprobs
if
flatten_logprobs
else
list
)
assert
isinstance
(
decode_logprobs
,
FlattenLogprobs
if
flatten_logprobs
else
list
)
assert
isinstance
(
prompt_logprobs
,
FlatLogprobs
if
flat_logprobs
else
list
)
assert
isinstance
(
decode_logprobs
,
FlatLogprobs
if
flat_logprobs
else
list
)
########################
# Check prompt logprobs
...
...
tests/test_logprobs.py
View file @
8d706cca
...
...
@@ -5,7 +5,7 @@
import
pytest
from
vllm.logprobs
import
(
Flat
ten
Logprobs
,
FlatLogprobs
,
Logprob
,
LogprobsOnePosition
,
append_logprobs_for_next_position
,
...
...
@@ -14,8 +14,8 @@ from vllm.logprobs import (
)
def
test_create_logprobs_non_flat
ten
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"0"
)
def
test_create_logprobs_non_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"0"
)
prompt_logprobs
=
create_prompt_logprobs
()
assert
isinstance
(
prompt_logprobs
,
list
)
...
...
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert
len
(
sample_logprobs
)
==
0
def
test_create_logprobs_flat
ten
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"1"
)
def
test_create_logprobs_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"1"
)
prompt_logprobs
=
create_prompt_logprobs
()
assert
isinstance
(
prompt_logprobs
,
Flat
ten
Logprobs
)
assert
isinstance
(
prompt_logprobs
,
FlatLogprobs
)
assert
prompt_logprobs
.
start_indices
==
[
0
]
assert
prompt_logprobs
.
end_indices
==
[
0
]
assert
len
(
prompt_logprobs
.
token_ids
)
==
0
...
...
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert
prompt_logprobs
[
0
]
==
dict
()
sample_logprobs
=
create_sample_logprobs
()
assert
isinstance
(
sample_logprobs
,
Flat
ten
Logprobs
)
assert
isinstance
(
sample_logprobs
,
FlatLogprobs
)
assert
len
(
sample_logprobs
.
start_indices
)
==
0
assert
len
(
sample_logprobs
.
end_indices
)
==
0
assert
len
(
sample_logprobs
.
token_ids
)
==
0
...
...
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert
len
(
sample_logprobs
)
==
0
def
test_append_logprobs_for_next_position_none_flat
ten
(
def
test_append_logprobs_for_next_position_none_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"0"
)
logprobs
=
create_sample_logprobs
()
append_logprobs_for_next_position
(
logprobs
,
...
...
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
]
def
test_append_logprobs_for_next_position_flat
ten
(
def
test_append_logprobs_for_next_position_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"1"
)
logprobs
=
create_sample_logprobs
()
append_logprobs_for_next_position
(
logprobs
,
...
...
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
rank
=
11
,
num_logprobs
=-
1
,
)
assert
isinstance
(
logprobs
,
Flat
ten
Logprobs
)
assert
isinstance
(
logprobs
,
FlatLogprobs
)
assert
logprobs
.
start_indices
==
[
0
,
1
]
assert
logprobs
.
end_indices
==
[
1
,
3
]
assert
logprobs
.
token_ids
==
[
1
,
2
,
3
]
...
...
@@ -129,8 +129,8 @@ LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
}
def
test_flat
ten
_logprobs_append
()
->
None
:
logprobs
=
Flat
ten
Logprobs
()
def
test_flat_logprobs_append
()
->
None
:
logprobs
=
FlatLogprobs
()
logprobs
.
append
(
LOGPROBS_ONE_POSITION_0
)
logprobs
.
append
(
LOGPROBS_ONE_POSITION_1
)
assert
logprobs
.
start_indices
==
[
0
,
1
]
...
...
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
assert
logprobs
.
decoded_tokens
==
[
"10"
,
"20"
,
"30"
,
"40"
,
"50"
,
"60"
]
def
test_flat
ten
_logprobs_extend
()
->
None
:
logprobs
=
Flat
ten
Logprobs
()
def
test_flat_logprobs_extend
()
->
None
:
logprobs
=
FlatLogprobs
()
# Extend with list[LogprobsOnePosition]
logprobs
.
extend
([
LOGPROBS_ONE_POSITION_2
,
LOGPROBS_ONE_POSITION_0
])
assert
logprobs
.
start_indices
==
[
0
,
3
]
...
...
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
assert
logprobs
.
ranks
==
[
40
,
50
,
60
,
10
]
assert
logprobs
.
decoded_tokens
==
[
"40"
,
"50"
,
"60"
,
"10"
]
other_logprobs
=
Flat
ten
Logprobs
()
other_logprobs
=
FlatLogprobs
()
other_logprobs
.
extend
([
LOGPROBS_ONE_POSITION_1
,
LOGPROBS_ONE_POSITION_0
])
# Extend with another Flat
ten
Logprobs
# Extend with another FlatLogprobs
logprobs
.
extend
(
other_logprobs
)
assert
logprobs
.
start_indices
==
[
0
,
3
,
4
,
6
]
assert
logprobs
.
end_indices
==
[
3
,
4
,
6
,
7
]
...
...
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
assert
logprobs
.
decoded_tokens
==
[
"40"
,
"50"
,
"60"
,
"10"
,
"20"
,
"30"
,
"10"
]
def
test_flat
ten
_logprobs_access
()
->
None
:
logprobs
=
Flat
ten
Logprobs
()
def
test_flat_logprobs_access
()
->
None
:
logprobs
=
FlatLogprobs
()
logprobs
.
extend
(
[
LOGPROBS_ONE_POSITION_1
,
LOGPROBS_ONE_POSITION_2
,
LOGPROBS_ONE_POSITION_0
]
)
...
...
vllm/envs.py
View file @
8d706cca
...
...
@@ -223,7 +223,7 @@ if TYPE_CHECKING:
VLLM_GC_DEBUG
:
str
=
""
VLLM_DISABLE_SHARED_EXPERTS_STREAM
:
bool
=
False
VLLM_COMPILE_CACHE_SAVE_FORMAT
:
Literal
[
"binary"
,
"unpacked"
]
=
"binary"
VLLM_FLAT
TEN
_LOGPROBS
:
bool
=
False
VLLM_FLAT_LOGPROBS
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_COMPILE_CACHE_SAVE_FORMAT"
:
env_with_choices
(
"VLLM_COMPILE_CACHE_SAVE_FORMAT"
,
"binary"
,
[
"binary"
,
"unpacked"
]
),
# Flag to enable Flat
ten
Logprobs whose GC overhead is significantly smaller than
# Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
# the original list[dict[int, Logprob]] approach.
# Once enabled, PromptLogprobs and SampleLogprobs are populated as
# Flat
ten
Logprobs.
"VLLM_FLAT
TEN
_LOGPROBS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"0"
))),
# FlatLogprobs.
"VLLM_FLAT_LOGPROBS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLAT_LOGPROBS"
,
"0"
))),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/logprobs.py
View file @
8d706cca
...
...
@@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]
@
dataclass
class
Flat
ten
Logprobs
(
MutableSequence
[
LogprobsOnePosition
]):
class
FlatLogprobs
(
MutableSequence
[
LogprobsOnePosition
]):
"""
Flat
ten
logprobs of a request into multiple primitive type lists.
Flattens logprobs of a request into multiple primitive type lists.
Compared to list[dict[int, Logprob]], this data structure reduces GC
overhead significantly, as it flattens logprob information for
all positions and ranks into multiple primitive type lists (i.e.
logprobs, token_ids, ranks per token_ids, decoded_tokens).
So regardless of the sequence length and top_logprobs setup,
Flat
ten
Logprobs would only introduce a constant amount of objects.
FlatLogprobs would only introduce a constant amount of objects.
As each position might contain a different number of ranks,
start_indices_per_position would be used to access the logprob ranges
...
...
@@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
def
__getitem__
(
self
,
position
:
int
)
->
LogprobsOnePosition
:
...
@
overload
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
"Flat
ten
Logprobs"
:
...
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
"FlatLogprobs"
:
...
def
__getitem__
(
self
,
index
:
int
|
slice
):
"""Extracts logprobs of a given position or slice"""
...
...
@@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
elif
isinstance
(
index
,
slice
):
min_index
=
self
.
start_indices
[
index
][
0
]
max_index
=
self
.
end_indices
[
index
][
-
1
]
return
Flat
ten
Logprobs
(
return
FlatLogprobs
(
# Shift updated start_indices and end_indices to
# be 0-indexed
start_indices
=
[
i
-
min_index
for
i
in
self
.
start_indices
[
index
]],
...
...
@@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
raise
TypeError
(
f
"Invalid index type:
{
type
(
index
)
}
"
)
def
__setitem__
(
self
,
item
,
value
)
->
None
:
raise
TypeError
(
"Cannot set logprobs in Flat
ten
Logprobs"
)
raise
TypeError
(
"Cannot set logprobs in FlatLogprobs"
)
def
__delitem__
(
self
,
item
)
->
None
:
raise
TypeError
(
"Cannot delete logprobs from Flat
ten
Logprobs"
)
raise
TypeError
(
"Cannot delete logprobs from FlatLogprobs"
)
def
insert
(
self
,
item
)
->
None
:
raise
TypeError
(
"Cannot insert logprobs to Flat
ten
Logprobs"
)
raise
TypeError
(
"Cannot insert logprobs to FlatLogprobs"
)
def
__iter__
(
self
)
->
Iterator
[
LogprobsOnePosition
]:
"""
...
...
@@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs
=
Flat
ten
Logprobs
|
list
[
LogprobsOnePosition
|
None
]
PromptLogprobs
=
FlatLogprobs
|
list
[
LogprobsOnePosition
|
None
]
# {token_id -> logprob} for each sequence group.
SampleLogprobs
=
Flat
ten
Logprobs
|
list
[
LogprobsOnePosition
]
SampleLogprobs
=
FlatLogprobs
|
list
[
LogprobsOnePosition
]
def
create_prompt_logprobs
()
->
PromptLogprobs
:
"""Creates a container to store prompt logprobs for a request"""
logprobs
=
Flat
ten
Logprobs
()
if
envs
.
VLLM_FLAT
TEN
_LOGPROBS
else
[]
logprobs
=
FlatLogprobs
()
if
envs
.
VLLM_FLAT_LOGPROBS
else
[]
# NOTE: logprob of first prompt token is None.
logprobs
.
append
(
None
)
return
logprobs
...
...
@@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
def
create_sample_logprobs
()
->
SampleLogprobs
:
"""Creates a container to store decode logprobs for a request"""
return
Flat
ten
Logprobs
()
if
envs
.
VLLM_FLAT
TEN
_LOGPROBS
else
[]
return
FlatLogprobs
()
if
envs
.
VLLM_FLAT_LOGPROBS
else
[]
def
append_logprobs_for_next_position
(
...
...
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
topk_ranks
=
range
(
1
,
num_logprobs
+
1
)
ranks
=
itertools
.
chain
((
rank
,),
topk_ranks
)
if
isinstance
(
request_logprobs
,
Flat
ten
Logprobs
):
if
isinstance
(
request_logprobs
,
FlatLogprobs
):
request_logprobs
.
append_fast
(
token_ids
,
logprobs
,
ranks
,
decoded_tokens
)
else
:
request_logprobs
.
append
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment