Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aabe8f40
Unverified
Commit
aabe8f40
authored
Apr 03, 2024
by
Matthias Gerstgrasser
Committed by
GitHub
Apr 03, 2024
Browse files
[Core] [Frontend] Make detokenization optional (#3749)
Co-authored-by:
Nick Hill
<
nickhill@us.ibm.com
>
parent
498eb5cf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
9 deletions
+53
-9
tests/engine/test_detokenization.py
tests/engine/test_detokenization.py
+32
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+11
-9
vllm/sampling_params.py
vllm/sampling_params.py
+10
-0
No files found.
tests/engine/test_detokenization.py
0 → 100644
View file @
aabe8f40
import
pytest
from
vllm.entrypoints.llm
import
LLM
from
vllm.sampling_params
import
SamplingParams
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
def
test_computed_prefix_blocks
(
model
:
str
):
# This test checks if the engine generates completions both with and
# without optional detokenization, that detokenization includes text
# and no-detokenization doesn't, and that both completions have the same
# token_ids.
prompt
=
(
"You are a helpful assistant. How do I build a car from cardboard and "
"paper clips? Is there an easy to follow video tutorial available "
"online for free?"
)
llm
=
LLM
(
model
=
model
)
sampling_params
=
SamplingParams
(
max_tokens
=
10
,
temperature
=
0.0
,
detokenize
=
False
)
outputs_no_detokenization
=
llm
.
generate
(
prompt
,
sampling_params
)[
0
].
outputs
[
0
]
sampling_params
.
detokenize
=
True
outputs_with_detokenization
=
llm
.
generate
(
prompt
,
sampling_params
)[
0
].
outputs
[
0
]
assert
outputs_no_detokenization
.
text
==
''
assert
outputs_with_detokenization
.
text
!=
''
assert
outputs_no_detokenization
.
token_ids
==
\
outputs_with_detokenization
.
token_ids
vllm/engine/llm_engine.py
View file @
aabe8f40
...
@@ -432,7 +432,7 @@ class LLMEngine:
...
@@ -432,7 +432,7 @@ class LLMEngine:
# Process prompt logprobs
# Process prompt logprobs
prompt_logprobs
=
outputs
.
prompt_logprobs
prompt_logprobs
=
outputs
.
prompt_logprobs
if
prompt_logprobs
is
not
None
:
if
prompt_logprobs
is
not
None
and
seq_group
.
sampling_params
.
detokenize
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
)
seq_group
,
prompt_logprobs
)
seq_group
.
prompt_logprobs
=
prompt_logprobs
seq_group
.
prompt_logprobs
=
prompt_logprobs
...
@@ -478,8 +478,9 @@ class LLMEngine:
...
@@ -478,8 +478,9 @@ class LLMEngine:
child_seqs
.
append
((
parent
,
parent
))
child_seqs
.
append
((
parent
,
parent
))
for
seq
,
_
in
child_seqs
:
for
seq
,
_
in
child_seqs
:
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
if
seq_group
.
sampling_params
.
detokenize
:
seq_group
.
sampling_params
)
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
seq_group
.
sampling_params
)
self
.
_check_stop
(
seq
,
seq_group
.
sampling_params
)
self
.
_check_stop
(
seq
,
seq_group
.
sampling_params
)
# Non-beam search case
# Non-beam search case
...
@@ -791,12 +792,13 @@ class LLMEngine:
...
@@ -791,12 +792,13 @@ class LLMEngine:
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
return
return
for
stop_str
in
sampling_params
.
stop
:
if
sampling_params
.
detokenize
:
if
seq
.
output_text
.
endswith
(
stop_str
):
for
stop_str
in
sampling_params
.
stop
:
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
if
seq
.
output_text
.
endswith
(
stop_str
):
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
stop_reason
=
stop_str
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
seq
.
stop_reason
=
stop_str
return
last_token_id
=
seq
.
get_last_token_id
()
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
sampling_params
.
stop_token_ids
:
if
last_token_id
in
sampling_params
.
stop_token_ids
:
stop_str
=
self
.
get_tokenizer_for_seq
(
seq
).
convert_ids_to_tokens
(
stop_str
=
self
.
get_tokenizer_for_seq
(
seq
).
convert_ids_to_tokens
(
...
...
vllm/sampling_params.py
View file @
aabe8f40
...
@@ -88,6 +88,7 @@ class SamplingParams:
...
@@ -88,6 +88,7 @@ class SamplingParams:
log probability of the sampled token, so there may be up to
log probability of the sampled token, so there may be up to
`logprobs+1` elements in the response.
`logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
tokens in the output. Defaults to True.
...
@@ -118,6 +119,7 @@ class SamplingParams:
...
@@ -118,6 +119,7 @@ class SamplingParams:
min_tokens
:
int
=
0
,
min_tokens
:
int
=
0
,
logprobs
:
Optional
[
int
]
=
None
,
logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
detokenize
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
...
@@ -150,6 +152,10 @@ class SamplingParams:
...
@@ -150,6 +152,10 @@ class SamplingParams:
self
.
min_tokens
=
min_tokens
self
.
min_tokens
=
min_tokens
self
.
logprobs
=
logprobs
self
.
logprobs
=
logprobs
self
.
prompt_logprobs
=
prompt_logprobs
self
.
prompt_logprobs
=
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self
.
detokenize
=
detokenize
self
.
skip_special_tokens
=
skip_special_tokens
self
.
skip_special_tokens
=
skip_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
logits_processors
=
logits_processors
self
.
logits_processors
=
logits_processors
...
@@ -210,6 +216,10 @@ class SamplingParams:
...
@@ -210,6 +216,10 @@ class SamplingParams:
if
self
.
prompt_logprobs
is
not
None
and
self
.
prompt_logprobs
<
0
:
if
self
.
prompt_logprobs
is
not
None
and
self
.
prompt_logprobs
<
0
:
raise
ValueError
(
f
"prompt_logprobs must be non-negative, got "
raise
ValueError
(
f
"prompt_logprobs must be non-negative, got "
f
"
{
self
.
prompt_logprobs
}
."
)
f
"
{
self
.
prompt_logprobs
}
."
)
if
self
.
stop
and
not
self
.
detokenize
:
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
def
_verify_beam_search
(
self
)
->
None
:
def
_verify_beam_search
(
self
)
->
None
:
if
self
.
best_of
==
1
:
if
self
.
best_of
==
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment