Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
974fc9b8
Unverified
Commit
974fc9b8
authored
Jun 04, 2024
by
zifeitong
Committed by
GitHub
Jun 04, 2024
Browse files
[Bugfix] Fix prompt_logprobs when SamplingParams.detokenize is set to True (#5226)
parent
fee4dcc3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
13 deletions
+22
-13
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+18
-9
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+4
-4
No files found.
tests/samplers/test_logprobs.py
View file @
974fc9b8
...
...
@@ -12,6 +12,7 @@ MODELS = ["facebook/opt-125m"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
6
])
# 32000 == vocab_size
@
pytest
.
mark
.
parametrize
(
"detokenize"
,
[
True
,
False
])
def
test_get_prompt_logprobs
(
hf_runner
,
vllm_runner
,
...
...
@@ -19,6 +20,7 @@ def test_get_prompt_logprobs(
dtype
,
chunked_prefill_token_size
:
int
,
num_top_logprobs
:
int
,
detokenize
:
bool
,
example_prompts
,
):
max_num_seqs
=
256
...
...
@@ -48,7 +50,8 @@ def test_get_prompt_logprobs(
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
num_top_logprobs
,
prompt_logprobs
=
num_top_logprobs
,
temperature
=
0.0
)
temperature
=
0.0
,
detokenize
=
detokenize
)
vllm_results
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
...
...
@@ -65,11 +68,16 @@ def test_get_prompt_logprobs(
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens
.
append
(
top_logprob
.
decoded_token
)
if
detokenize
:
output_string_from_most_likely_tokens
=
""
.
join
(
output_string_from_most_likely_tokens
)
assert
output_text
==
output_string_from_most_likely_tokens
,
(
"The output text from the top logprob for each token position "
"should be the same as the output text in the result."
)
else
:
assert
output_text
==
''
assert
output_string_from_most_likely_tokens
==
[
None
]
*
max_tokens
# The first prompt logprob is always None
assert
result
.
prompt_logprobs
[
0
]
is
None
...
...
@@ -98,8 +106,9 @@ def test_get_prompt_logprobs(
hf_logprob
[
i
][
-
1
][
token_id
].
item
(),
atol
=
1e-2
,
rtol
=
1e-2
)
if
detokenize
:
assert
isinstance
(
sample_logprob
.
decoded_token
,
str
),
(
"The token should be decoded by the time it is returned
"
"The token should be decoded by the time it is returned"
" to the user."
)
# Test if prompt logprobs are correctly set.
...
...
vllm/engine/output_processor/single_step.py
View file @
974fc9b8
...
...
@@ -60,8 +60,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
output
=
outputs
[
0
]
prompt_logprobs
=
output
.
prompt_logprobs
if
(
prompt_logprobs
is
not
None
and
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
)
:
if
prompt_logprobs
is
not
None
:
if
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
)
if
not
seq_group
.
prompt_logprobs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment