Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d706cca
Unverified
Commit
8d706cca
authored
Nov 10, 2025
by
Zhuohan Li
Committed by
GitHub
Nov 11, 2025
Browse files
[Misc] FlattenLogprobs -> FlatLogprobs (#28335)
parent
57201a6a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
47 deletions
+43
-47
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+6
-10
tests/test_logprobs.py
tests/test_logprobs.py
+20
-20
vllm/envs.py
vllm/envs.py
+4
-4
vllm/logprobs.py
vllm/logprobs.py
+13
-13
No files found.
tests/samplers/test_logprobs.py
View file @
8d706cca
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
pytest
import
pytest
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.logprobs
import
Flat
ten
Logprobs
from
vllm.logprobs
import
FlatLogprobs
MODELS
=
[
"distilbert/distilgpt2"
]
MODELS
=
[
"distilbert/distilgpt2"
]
MAX_TOKENS
=
5
MAX_TOKENS
=
5
...
@@ -16,17 +16,17 @@ MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
...
@@ -16,17 +16,17 @@ MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"greedy"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"greedy"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"flat
ten
_logprobs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"flat_logprobs"
,
[
True
,
False
])
def
test_ranks
(
def
test_ranks
(
vllm_runner
,
vllm_runner
,
model
,
model
,
dtype
,
dtype
,
greedy
,
greedy
,
flat
ten
_logprobs
,
flat_logprobs
,
example_prompts
,
example_prompts
,
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
):
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"1"
if
flat
ten
_logprobs
else
"0"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"1"
if
flat_logprobs
else
"0"
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
MAX_LOGPROBS
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
MAX_LOGPROBS
)
as
vllm_model
:
tokenizer
=
vllm_model
.
llm
.
get_tokenizer
()
tokenizer
=
vllm_model
.
llm
.
get_tokenizer
()
example_prompt_tokens
=
[
tokenizer
.
encode
(
prompt
)
for
prompt
in
example_prompts
]
example_prompt_tokens
=
[
tokenizer
.
encode
(
prompt
)
for
prompt
in
example_prompts
]
...
@@ -44,12 +44,8 @@ def test_ranks(
...
@@ -44,12 +44,8 @@ def test_ranks(
decode_tokens
,
_
,
decode_logprobs
,
prompt_logprobs
=
result
decode_tokens
,
_
,
decode_logprobs
,
prompt_logprobs
=
result
# Ensure the return type of logprobs is accurate
# Ensure the return type of logprobs is accurate
assert
isinstance
(
assert
isinstance
(
prompt_logprobs
,
FlatLogprobs
if
flat_logprobs
else
list
)
prompt_logprobs
,
FlattenLogprobs
if
flatten_logprobs
else
list
assert
isinstance
(
decode_logprobs
,
FlatLogprobs
if
flat_logprobs
else
list
)
)
assert
isinstance
(
decode_logprobs
,
FlattenLogprobs
if
flatten_logprobs
else
list
)
########################
########################
# Check prompt logprobs
# Check prompt logprobs
...
...
tests/test_logprobs.py
View file @
8d706cca
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
import
pytest
import
pytest
from
vllm.logprobs
import
(
from
vllm.logprobs
import
(
Flat
ten
Logprobs
,
FlatLogprobs
,
Logprob
,
Logprob
,
LogprobsOnePosition
,
LogprobsOnePosition
,
append_logprobs_for_next_position
,
append_logprobs_for_next_position
,
...
@@ -14,8 +14,8 @@ from vllm.logprobs import (
...
@@ -14,8 +14,8 @@ from vllm.logprobs import (
)
)
def
test_create_logprobs_non_flat
ten
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
def
test_create_logprobs_non_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"0"
)
prompt_logprobs
=
create_prompt_logprobs
()
prompt_logprobs
=
create_prompt_logprobs
()
assert
isinstance
(
prompt_logprobs
,
list
)
assert
isinstance
(
prompt_logprobs
,
list
)
...
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
...
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert
len
(
sample_logprobs
)
==
0
assert
len
(
sample_logprobs
)
==
0
def
test_create_logprobs_flat
ten
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
def
test_create_logprobs_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"1"
)
prompt_logprobs
=
create_prompt_logprobs
()
prompt_logprobs
=
create_prompt_logprobs
()
assert
isinstance
(
prompt_logprobs
,
Flat
ten
Logprobs
)
assert
isinstance
(
prompt_logprobs
,
FlatLogprobs
)
assert
prompt_logprobs
.
start_indices
==
[
0
]
assert
prompt_logprobs
.
start_indices
==
[
0
]
assert
prompt_logprobs
.
end_indices
==
[
0
]
assert
prompt_logprobs
.
end_indices
==
[
0
]
assert
len
(
prompt_logprobs
.
token_ids
)
==
0
assert
len
(
prompt_logprobs
.
token_ids
)
==
0
...
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
...
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert
prompt_logprobs
[
0
]
==
dict
()
assert
prompt_logprobs
[
0
]
==
dict
()
sample_logprobs
=
create_sample_logprobs
()
sample_logprobs
=
create_sample_logprobs
()
assert
isinstance
(
sample_logprobs
,
Flat
ten
Logprobs
)
assert
isinstance
(
sample_logprobs
,
FlatLogprobs
)
assert
len
(
sample_logprobs
.
start_indices
)
==
0
assert
len
(
sample_logprobs
.
start_indices
)
==
0
assert
len
(
sample_logprobs
.
end_indices
)
==
0
assert
len
(
sample_logprobs
.
end_indices
)
==
0
assert
len
(
sample_logprobs
.
token_ids
)
==
0
assert
len
(
sample_logprobs
.
token_ids
)
==
0
...
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
...
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert
len
(
sample_logprobs
)
==
0
assert
len
(
sample_logprobs
)
==
0
def
test_append_logprobs_for_next_position_none_flat
ten
(
def
test_append_logprobs_for_next_position_none_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"0"
)
logprobs
=
create_sample_logprobs
()
logprobs
=
create_sample_logprobs
()
append_logprobs_for_next_position
(
append_logprobs_for_next_position
(
logprobs
,
logprobs
,
...
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
...
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
]
]
def
test_append_logprobs_for_next_position_flat
ten
(
def
test_append_logprobs_for_next_position_flat
(
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
)
->
None
:
monkeypatch
.
setenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_FLAT_LOGPROBS"
,
"1"
)
logprobs
=
create_sample_logprobs
()
logprobs
=
create_sample_logprobs
()
append_logprobs_for_next_position
(
append_logprobs_for_next_position
(
logprobs
,
logprobs
,
...
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
...
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
rank
=
11
,
rank
=
11
,
num_logprobs
=-
1
,
num_logprobs
=-
1
,
)
)
assert
isinstance
(
logprobs
,
Flat
ten
Logprobs
)
assert
isinstance
(
logprobs
,
FlatLogprobs
)
assert
logprobs
.
start_indices
==
[
0
,
1
]
assert
logprobs
.
start_indices
==
[
0
,
1
]
assert
logprobs
.
end_indices
==
[
1
,
3
]
assert
logprobs
.
end_indices
==
[
1
,
3
]
assert
logprobs
.
token_ids
==
[
1
,
2
,
3
]
assert
logprobs
.
token_ids
==
[
1
,
2
,
3
]
...
@@ -129,8 +129,8 @@ LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
...
@@ -129,8 +129,8 @@ LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
}
}
def
test_flat
ten
_logprobs_append
()
->
None
:
def
test_flat_logprobs_append
()
->
None
:
logprobs
=
Flat
ten
Logprobs
()
logprobs
=
FlatLogprobs
()
logprobs
.
append
(
LOGPROBS_ONE_POSITION_0
)
logprobs
.
append
(
LOGPROBS_ONE_POSITION_0
)
logprobs
.
append
(
LOGPROBS_ONE_POSITION_1
)
logprobs
.
append
(
LOGPROBS_ONE_POSITION_1
)
assert
logprobs
.
start_indices
==
[
0
,
1
]
assert
logprobs
.
start_indices
==
[
0
,
1
]
...
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
...
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
assert
logprobs
.
decoded_tokens
==
[
"10"
,
"20"
,
"30"
,
"40"
,
"50"
,
"60"
]
assert
logprobs
.
decoded_tokens
==
[
"10"
,
"20"
,
"30"
,
"40"
,
"50"
,
"60"
]
def
test_flat
ten
_logprobs_extend
()
->
None
:
def
test_flat_logprobs_extend
()
->
None
:
logprobs
=
Flat
ten
Logprobs
()
logprobs
=
FlatLogprobs
()
# Extend with list[LogprobsOnePosition]
# Extend with list[LogprobsOnePosition]
logprobs
.
extend
([
LOGPROBS_ONE_POSITION_2
,
LOGPROBS_ONE_POSITION_0
])
logprobs
.
extend
([
LOGPROBS_ONE_POSITION_2
,
LOGPROBS_ONE_POSITION_0
])
assert
logprobs
.
start_indices
==
[
0
,
3
]
assert
logprobs
.
start_indices
==
[
0
,
3
]
...
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
...
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
assert
logprobs
.
ranks
==
[
40
,
50
,
60
,
10
]
assert
logprobs
.
ranks
==
[
40
,
50
,
60
,
10
]
assert
logprobs
.
decoded_tokens
==
[
"40"
,
"50"
,
"60"
,
"10"
]
assert
logprobs
.
decoded_tokens
==
[
"40"
,
"50"
,
"60"
,
"10"
]
other_logprobs
=
Flat
ten
Logprobs
()
other_logprobs
=
FlatLogprobs
()
other_logprobs
.
extend
([
LOGPROBS_ONE_POSITION_1
,
LOGPROBS_ONE_POSITION_0
])
other_logprobs
.
extend
([
LOGPROBS_ONE_POSITION_1
,
LOGPROBS_ONE_POSITION_0
])
# Extend with another Flat
ten
Logprobs
# Extend with another FlatLogprobs
logprobs
.
extend
(
other_logprobs
)
logprobs
.
extend
(
other_logprobs
)
assert
logprobs
.
start_indices
==
[
0
,
3
,
4
,
6
]
assert
logprobs
.
start_indices
==
[
0
,
3
,
4
,
6
]
assert
logprobs
.
end_indices
==
[
3
,
4
,
6
,
7
]
assert
logprobs
.
end_indices
==
[
3
,
4
,
6
,
7
]
...
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
...
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
assert
logprobs
.
decoded_tokens
==
[
"40"
,
"50"
,
"60"
,
"10"
,
"20"
,
"30"
,
"10"
]
assert
logprobs
.
decoded_tokens
==
[
"40"
,
"50"
,
"60"
,
"10"
,
"20"
,
"30"
,
"10"
]
def
test_flat
ten
_logprobs_access
()
->
None
:
def
test_flat_logprobs_access
()
->
None
:
logprobs
=
Flat
ten
Logprobs
()
logprobs
=
FlatLogprobs
()
logprobs
.
extend
(
logprobs
.
extend
(
[
LOGPROBS_ONE_POSITION_1
,
LOGPROBS_ONE_POSITION_2
,
LOGPROBS_ONE_POSITION_0
]
[
LOGPROBS_ONE_POSITION_1
,
LOGPROBS_ONE_POSITION_2
,
LOGPROBS_ONE_POSITION_0
]
)
)
...
...
vllm/envs.py
View file @
8d706cca
...
@@ -223,7 +223,7 @@ if TYPE_CHECKING:
...
@@ -223,7 +223,7 @@ if TYPE_CHECKING:
VLLM_GC_DEBUG
:
str
=
""
VLLM_GC_DEBUG
:
str
=
""
VLLM_DISABLE_SHARED_EXPERTS_STREAM
:
bool
=
False
VLLM_DISABLE_SHARED_EXPERTS_STREAM
:
bool
=
False
VLLM_COMPILE_CACHE_SAVE_FORMAT
:
Literal
[
"binary"
,
"unpacked"
]
=
"binary"
VLLM_COMPILE_CACHE_SAVE_FORMAT
:
Literal
[
"binary"
,
"unpacked"
]
=
"binary"
VLLM_FLAT
TEN
_LOGPROBS
:
bool
=
False
VLLM_FLAT_LOGPROBS
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_COMPILE_CACHE_SAVE_FORMAT"
:
env_with_choices
(
"VLLM_COMPILE_CACHE_SAVE_FORMAT"
:
env_with_choices
(
"VLLM_COMPILE_CACHE_SAVE_FORMAT"
,
"binary"
,
[
"binary"
,
"unpacked"
]
"VLLM_COMPILE_CACHE_SAVE_FORMAT"
,
"binary"
,
[
"binary"
,
"unpacked"
]
),
),
# Flag to enable Flat
ten
Logprobs whose GC overhead is significantly smaller than
# Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
# the original list[dict[int, Logprob]] approach.
# the original list[dict[int, Logprob]] approach.
# Once enabled, PromptLogprobs and SampleLogprobs are populated as
# Once enabled, PromptLogprobs and SampleLogprobs are populated as
# Flat
ten
Logprobs.
# FlatLogprobs.
"VLLM_FLAT
TEN
_LOGPROBS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLAT
TEN
_LOGPROBS"
,
"0"
))),
"VLLM_FLAT_LOGPROBS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLAT_LOGPROBS"
,
"0"
))),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/logprobs.py
View file @
8d706cca
...
@@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]
...
@@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]
@
dataclass
@
dataclass
class
Flat
ten
Logprobs
(
MutableSequence
[
LogprobsOnePosition
]):
class
FlatLogprobs
(
MutableSequence
[
LogprobsOnePosition
]):
"""
"""
Flat
ten
logprobs of a request into multiple primitive type lists.
Flat logprobs of a request into multiple primitive type lists.
Compared to list[dict[int, Logprob]], this data structure reduces GC
Compared to list[dict[int, Logprob]], this data structure reduces GC
overhead significantly, as it flattens logprob information for
overhead significantly, as it flattens logprob information for
all positions and ranks into multiple primitive type lists (i.e.
all positions and ranks into multiple primitive type lists (i.e.
logprobs, token_ids, ranks per token_ids, decoded_tokens).
logprobs, token_ids, ranks per token_ids, decoded_tokens).
So regardless of the sequence length and top_logprobs setup,
So regardless of the sequence length and top_logprobs setup,
Flat
ten
Logprobs would only introduce a constant amount of objects.
FlatLogprobs would only introduce a constant amount of objects.
As each position might contain a different number of ranks,
As each position might contain a different number of ranks,
start_indices_per_position would be used to access the logprob ranges
start_indices_per_position would be used to access the logprob ranges
...
@@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
...
@@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
def
__getitem__
(
self
,
position
:
int
)
->
LogprobsOnePosition
:
...
def
__getitem__
(
self
,
position
:
int
)
->
LogprobsOnePosition
:
...
@
overload
@
overload
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
"Flat
ten
Logprobs"
:
...
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
"FlatLogprobs"
:
...
def
__getitem__
(
self
,
index
:
int
|
slice
):
def
__getitem__
(
self
,
index
:
int
|
slice
):
"""Extracts logprobs of a given position or slice"""
"""Extracts logprobs of a given position or slice"""
...
@@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
...
@@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
elif
isinstance
(
index
,
slice
):
elif
isinstance
(
index
,
slice
):
min_index
=
self
.
start_indices
[
index
][
0
]
min_index
=
self
.
start_indices
[
index
][
0
]
max_index
=
self
.
end_indices
[
index
][
-
1
]
max_index
=
self
.
end_indices
[
index
][
-
1
]
return
Flat
ten
Logprobs
(
return
FlatLogprobs
(
# Shift updated start_indices and end_indices to
# Shift updated start_indices and end_indices to
# be 0-indexed
# be 0-indexed
start_indices
=
[
i
-
min_index
for
i
in
self
.
start_indices
[
index
]],
start_indices
=
[
i
-
min_index
for
i
in
self
.
start_indices
[
index
]],
...
@@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
...
@@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
raise
TypeError
(
f
"Invalid index type:
{
type
(
index
)
}
"
)
raise
TypeError
(
f
"Invalid index type:
{
type
(
index
)
}
"
)
def
__setitem__
(
self
,
item
,
value
)
->
None
:
def
__setitem__
(
self
,
item
,
value
)
->
None
:
raise
TypeError
(
"Cannot set logprobs in Flat
ten
Logprobs"
)
raise
TypeError
(
"Cannot set logprobs in FlatLogprobs"
)
def
__delitem__
(
self
,
item
)
->
None
:
def
__delitem__
(
self
,
item
)
->
None
:
raise
TypeError
(
"Cannot delete logprobs from Flat
ten
Logprobs"
)
raise
TypeError
(
"Cannot delete logprobs from FlatLogprobs"
)
def
insert
(
self
,
item
)
->
None
:
def
insert
(
self
,
item
)
->
None
:
raise
TypeError
(
"Cannot insert logprobs to Flat
ten
Logprobs"
)
raise
TypeError
(
"Cannot insert logprobs to FlatLogprobs"
)
def
__iter__
(
self
)
->
Iterator
[
LogprobsOnePosition
]:
def
__iter__
(
self
)
->
Iterator
[
LogprobsOnePosition
]:
"""
"""
...
@@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
...
@@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
# {token_id -> logprob} per each sequence group. None if the corresponding
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
# sequence group doesn't require prompt logprob.
PromptLogprobs
=
Flat
ten
Logprobs
|
list
[
LogprobsOnePosition
|
None
]
PromptLogprobs
=
FlatLogprobs
|
list
[
LogprobsOnePosition
|
None
]
# {token_id -> logprob} for each sequence group.
# {token_id -> logprob} for each sequence group.
SampleLogprobs
=
Flat
ten
Logprobs
|
list
[
LogprobsOnePosition
]
SampleLogprobs
=
FlatLogprobs
|
list
[
LogprobsOnePosition
]
def
create_prompt_logprobs
()
->
PromptLogprobs
:
def
create_prompt_logprobs
()
->
PromptLogprobs
:
"""Creates a container to store prompt logprobs for a request"""
"""Creates a container to store prompt logprobs for a request"""
logprobs
=
Flat
ten
Logprobs
()
if
envs
.
VLLM_FLAT
TEN
_LOGPROBS
else
[]
logprobs
=
FlatLogprobs
()
if
envs
.
VLLM_FLAT_LOGPROBS
else
[]
# NOTE: logprob of first prompt token is None.
# NOTE: logprob of first prompt token is None.
logprobs
.
append
(
None
)
logprobs
.
append
(
None
)
return
logprobs
return
logprobs
...
@@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
...
@@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
def
create_sample_logprobs
()
->
SampleLogprobs
:
def
create_sample_logprobs
()
->
SampleLogprobs
:
"""Creates a container to store decode logprobs for a request"""
"""Creates a container to store decode logprobs for a request"""
return
Flat
ten
Logprobs
()
if
envs
.
VLLM_FLAT
TEN
_LOGPROBS
else
[]
return
FlatLogprobs
()
if
envs
.
VLLM_FLAT_LOGPROBS
else
[]
def
append_logprobs_for_next_position
(
def
append_logprobs_for_next_position
(
...
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
...
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
topk_ranks
=
range
(
1
,
num_logprobs
+
1
)
topk_ranks
=
range
(
1
,
num_logprobs
+
1
)
ranks
=
itertools
.
chain
((
rank
,),
topk_ranks
)
ranks
=
itertools
.
chain
((
rank
,),
topk_ranks
)
if
isinstance
(
request_logprobs
,
Flat
ten
Logprobs
):
if
isinstance
(
request_logprobs
,
FlatLogprobs
):
request_logprobs
.
append_fast
(
token_ids
,
logprobs
,
ranks
,
decoded_tokens
)
request_logprobs
.
append_fast
(
token_ids
,
logprobs
,
ranks
,
decoded_tokens
)
else
:
else
:
request_logprobs
.
append
(
request_logprobs
.
append
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment