Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7ed6a4f0
Unverified
Commit
7ed6a4f0
authored
Jul 11, 2024
by
Robert Shaw
Committed by
GitHub
Jul 11, 2024
Browse files
[ BugFix ] Prompt Logprobs Detokenization (#6223)
Co-authored-by:
Zifei Tong
<
zifeitong@gmail.com
>
parent
a4feba92
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
117 additions
and
32 deletions
+117
-32
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+4
-1
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+87
-22
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+14
-5
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+12
-4
No files found.
.buildkite/test-pipeline.yaml
View file @
7ed6a4f0
...
...
@@ -87,7 +87,10 @@ steps:
-
label
:
Engine Test
mirror_hardwares
:
[
amd
]
command
:
pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
commands
:
-
pytest -v -s engine test_sequence.py test_config.py test_logger.py
# OOM in the CI unless we run this separately
-
pytest -v -s tokenization
-
label
:
Entrypoints Test
mirror_hardwares
:
[
amd
]
...
...
tests/tokenization/test_detokenize.py
View file @
7ed6a4f0
from
typing
import
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Optional
import
pytest
from
transformers
import
AutoTokenizer
...
...
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
}
for
token_id
in
complete_sequence_token_ids
]
def
create_dummy_prompt_logprobs
(
complete_sequence_token_ids
:
List
[
int
]
)
->
List
[
Optional
[
Dict
[
int
,
Any
]]]:
# logprob for the first prompt token is None.
logprobs
:
List
[
Optional
[
Dict
[
int
,
Any
]]]
=
[
None
]
logprobs
.
extend
(
create_dummy_logprobs
(
complete_sequence_token_ids
)[
1
:])
return
logprobs
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
[
True
,
False
])
...
...
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
[
True
])
def
test_decode_prompt_logprobs
(
complete_sequence
:
str
,
complete_sequence_token_ids
:
List
[
int
],
detokenizer
:
Detokenizer
,
skip_special_tokens
:
bool
):
def
test_decode_prompt_logprobs
(
complete_sequence_token_ids
:
List
[
int
],
detokenizer
:
Detokenizer
):
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params
=
SamplingParams
(
skip_special_tokens
=
skip_special_tokens
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
True
,
prompt_logprobs
=
1
)
# Run sequentially.
...
...
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
seqs
=
[
seq
],
sampling_params
=
sampling_params
,
arrival_time
=
0.0
)
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
dummy_logprobs
)
decoded_prompt_logprobs
=
dummy_logprobs
dummy_logprobs
=
create_dummy_prompt_logprobs
(
complete_sequence_token_ids
)
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
dummy_logprobs
,
position_offset
=
0
)
# First logprob is None.
decoded_prompt_logprobs
:
List
[
Dict
[
int
,
Any
]]
=
dummy_logprobs
[
1
:]
# type: ignore
if
skip_special_tokens
:
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that this will only be true if we skip
# special tokens.
assert
complete_sequence
==
""
.
join
([
logprobs
[
token_id
].
decoded_token
for
token_id
,
logprobs
in
zip
(
complete_sequence_token_ids
,
decoded_prompt_logprobs
)
])
assert
complete_sequence
!=
""
.
join
([
logprobs
[
token_id
+
1
].
decoded_token
for
token_id
,
logprobs
in
zip
(
complete_sequence_token_ids
,
decoded_prompt_logprobs
)
])
# decoded_prompt_logprobs doesn't contain the first token.
token_ids
=
complete_sequence_token_ids
tokenzier
=
detokenizer
.
get_tokenizer_for_seq
(
seq
)
text_full
=
tokenzier
.
decode
(
token_ids
,
skip_special_tokens
=
True
)
text_first
=
tokenzier
.
decode
(
token_ids
[
0
],
skip_special_tokens
=
True
)
text
=
text_full
[
len
(
text_first
):]
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that the first logprob is None.
assert
text
==
""
.
join
([
logprobs
[
token_id
].
decoded_token
for
token_id
,
logprobs
in
zip
(
token_ids
[
1
:],
decoded_prompt_logprobs
)
])
assert
text
!=
""
.
join
([
logprobs
[
token_id
+
1
].
decoded_token
for
token_id
,
logprobs
in
zip
(
token_ids
[
1
:],
decoded_prompt_logprobs
)
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
7
,
16
,
-
1
])
def
test_decode_prompt_logprobs_chunked_prefill
(
vllm_runner
,
model
,
chunked_prefill_token_size
:
int
,
example_prompts
,
):
max_num_seqs
=
256
enable_chunked_prefill
=
False
max_num_batched_tokens
=
None
if
chunked_prefill_token_size
!=
-
1
:
enable_chunked_prefill
=
True
max_num_seqs
=
min
(
chunked_prefill_token_size
,
max_num_seqs
)
max_num_batched_tokens
=
chunked_prefill_token_size
with
vllm_runner
(
model
,
dtype
=
"half"
,
max_logprobs
=
5
,
gpu_memory_utilization
=
0.5
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_seqs
)
as
vllm_model
:
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
10
,
logprobs
=
5
,
prompt_logprobs
=
5
,
temperature
=
0.0
)
vllm_results
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
for
idx
,
result
in
enumerate
(
vllm_results
):
assert
result
.
prompt_logprobs
is
not
None
assert
result
.
prompt_logprobs
[
0
]
is
None
# Compared detokenized prompts ids to original prompt.
generated_string
=
""
for
(
prompt_token
,
prompt_logprobs
)
in
zip
(
result
.
prompt_token_ids
[
1
:],
result
.
prompt_logprobs
[
1
:]):
# prompt_logprobs is a dict of the token_id: logprob
# We select the token_id corresponding to the actual prompt
# Decoded token in the detokenized string corresponding to this
# prompt token.
generated_string
+=
prompt_logprobs
[
prompt_token
].
decoded_token
assert
generated_string
==
example_prompts
[
idx
],
(
"Detokenized prompt logprobs do not match original prompt"
)
vllm/engine/output_processor/single_step.py
View file @
7ed6a4f0
...
...
@@ -60,14 +60,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
output
=
outputs
[
0
]
prompt_logprobs
=
output
.
prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if
prompt_logprobs
is
not
None
:
if
not
seq_group
.
prompt_logprobs
:
prompt_logprobs
=
[
None
]
+
prompt_logprobs
seq_group
.
prompt_logprobs
=
[]
if
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
)
if
not
seq_group
.
prompt_logprobs
:
# The first prompt token's logprob is None because it doesn't
# have tokens that are precedent.
seq_group
.
prompt_logprobs
=
[
None
]
seq_group
,
prompt_logprobs
,
position_offset
=
len
(
seq_group
.
prompt_logprobs
))
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
...
...
vllm/transformers_utils/detokenizer.py
View file @
7ed6a4f0
...
...
@@ -21,14 +21,17 @@ class Detokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return
self
.
tokenizer_group
.
get_lora_tokenizer
(
sequence
.
lora_request
)
def
decode_prompt_logprobs_inplace
(
self
,
seq_group
:
SequenceGroup
,
prompt_logprobs
:
List
[
Optional
[
Dict
[
int
,
Logprob
]]])
->
None
:
def
decode_prompt_logprobs_inplace
(
self
,
seq_group
:
SequenceGroup
,
prompt_logprobs
:
List
[
Optional
[
Dict
[
int
,
Logprob
]]],
position_offset
:
int
)
->
None
:
"""Decodes the logprobs for the prompt of a sequence group.
Args:
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill).
Returns:
The prompt logprobs with the decoded tokens.
...
...
@@ -47,8 +50,13 @@ class Detokenizer:
next_iter_tokens
:
List
[
str
]
=
[]
prev_tokens
=
None
for
token_position
,
prompt_logprobs_for_token
in
enumerate
(
for
token_position
_in_logprob
,
prompt_logprobs_for_token
in
enumerate
(
prompt_logprobs
):
# Absolute token position equals the index in the logprobs
# list plus the offset of the entire logprobs list relative
# to the start of the sequence.
token_position
=
token_position_in_logprob
+
position_offset
if
not
prompt_logprobs_for_token
:
continue
for
token_id
,
sample_logprob
in
prompt_logprobs_for_token
.
items
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment