Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e46a60aa
Unverified
Commit
e46a60aa
authored
Apr 11, 2024
by
Nick Hill
Committed by
GitHub
Apr 11, 2024
Browse files
[BugFix] Fix handling of stop strings and stop token ids (#3672)
parent
1e96c334
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
202 additions
and
37 deletions
+202
-37
tests/conftest.py
tests/conftest.py
+1
-1
tests/engine/test_stop_reason.py
tests/engine/test_stop_reason.py
+1
-1
tests/engine/test_stop_strings.py
tests/engine/test_stop_strings.py
+111
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+65
-33
vllm/outputs.py
vllm/outputs.py
+3
-1
vllm/sampling_params.py
vllm/sampling_params.py
+9
-0
vllm/sequence.py
vllm/sequence.py
+6
-0
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+6
-1
No files found.
tests/conftest.py
View file @
e46a60aa
...
@@ -401,7 +401,7 @@ class VllmRunner:
...
@@ -401,7 +401,7 @@ class VllmRunner:
cleanup
()
cleanup
()
@
pytest
.
fixture
@
pytest
.
fixture
(
scope
=
"session"
)
def
vllm_runner
():
def
vllm_runner
():
return
VllmRunner
return
VllmRunner
...
...
tests/
samplers
/test_stop_reason.py
→
tests/
engine
/test_stop_reason.py
View file @
e46a60aa
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
2. One of the provided stop tokens
2. One of the provided stop tokens
3. The EOS token
3. The EOS token
Run `pytest tests/
samplers
/test_stop_reason.py`.
Run `pytest tests/
engine
/test_stop_reason.py`.
"""
"""
import
pytest
import
pytest
...
...
tests/engine/test_stop_strings.py
0 → 100644
View file @
e46a60aa
from
typing
import
Any
,
List
,
Optional
import
pytest
from
vllm
import
CompletionOutput
,
LLMEngine
,
SamplingParams
MODEL
=
"meta-llama/llama-2-7b-hf"
MAX_TOKENS
=
200
@
pytest
.
fixture
(
scope
=
"session"
)
def
vllm_model
(
vllm_runner
):
return
vllm_runner
(
MODEL
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_basic
(
vllm_model
):
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop
=
[
"."
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_reason
=
"."
)
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop
=
[
"."
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization."
,
expected_reason
=
"."
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_multi_tokens
(
vllm_model
):
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop
=
[
"group of peo"
,
"short"
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer organization. We are a "
,
expected_reason
=
"group of peo"
)
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop
=
[
"group of peo"
,
"short"
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization. We are a group of peo"
,
expected_reason
=
"group of peo"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_partial_token
(
vllm_model
):
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop
=
[
"gani"
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer or"
,
expected_reason
=
"gani"
)
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop
=
[
"gani"
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organi"
,
expected_reason
=
"gani"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_token_id
(
vllm_model
):
# token id 13013 => " organization"
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop_token_ids
=
[
13013
],
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer"
,
expected_reason
=
13013
)
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop_token_ids
=
[
13013
],
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_reason
=
13013
)
def
_test_stopping
(
llm_engine
:
LLMEngine
,
expected_output
:
str
,
expected_reason
:
Any
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
include_in_output
:
bool
=
False
)
->
None
:
llm_engine
.
add_request
(
"id"
,
"A story about vLLM:
\n
"
,
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
MAX_TOKENS
,
stop
=
stop
,
stop_token_ids
=
stop_token_ids
,
include_stop_str_in_output
=
include_in_output
,
),
None
)
output
:
Optional
[
CompletionOutput
]
=
None
output_text
=
""
stop_reason
=
None
while
llm_engine
.
has_unfinished_requests
():
(
request_output
,
)
=
llm_engine
.
step
()
(
output
,
)
=
request_output
.
outputs
# Ensure we don't backtrack
assert
output
.
text
.
startswith
(
output_text
)
output_text
=
output
.
text
stop_reason
=
output
.
stop_reason
assert
output
is
not
None
assert
output_text
==
expected_output
assert
stop_reason
==
expected_reason
vllm/engine/llm_engine.py
View file @
e46a60aa
...
@@ -501,9 +501,11 @@ class LLMEngine:
...
@@ -501,9 +501,11 @@ class LLMEngine:
for
seq
,
_
in
child_seqs
:
for
seq
,
_
in
child_seqs
:
if
seq_group
.
sampling_params
.
detokenize
:
if
seq_group
.
sampling_params
.
detokenize
:
self
.
detokenizer
.
decode_sequence_inplace
(
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
seq_group
.
sampling_params
)
seq
,
seq_group
.
sampling_params
)
self
.
_check_stop
(
seq
,
seq_group
.
sampling_params
)
else
:
new_char_count
=
0
self
.
_check_stop
(
seq
,
new_char_count
,
seq_group
.
sampling_params
)
# Non-beam search case
# Non-beam search case
if
not
seq_group
.
sampling_params
.
use_beam_search
:
if
not
seq_group
.
sampling_params
.
use_beam_search
:
...
@@ -798,56 +800,86 @@ class LLMEngine:
...
@@ -798,56 +800,86 @@ class LLMEngine:
time_e2e_requests
=
time_e2e_requests
,
time_e2e_requests
=
time_e2e_requests
,
)
)
def
_check_stop
(
self
,
seq
:
Sequence
,
def
_check_stop
(
self
,
seq
:
Sequence
,
new_char_count
:
int
,
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
None
:
"""Stop the finished sequences."""
"""Stop the finished sequences.
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
scheduler_config
.
max_model_len
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
new_char_count is the number of chars added to the
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
sequence's output text for the newly generated token
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
"""
return
# Check if the minimum number of tokens has been generated yet;
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
# skip the stop string/token checks if not
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
return
return
if
sampling_params
.
detokenize
:
# Check if the sequence has generated the EOS token.
for
stop_str
in
sampling_params
.
stop
:
if
((
not
sampling_params
.
ignore_eos
)
if
seq
.
output_text
.
endswith
(
stop_str
):
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
seq
.
stop_reason
=
stop_str
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
sampling_params
.
stop_token_ids
:
if
last_token_id
in
sampling_params
.
stop_token_ids
:
stop_str
=
self
.
get_tokenizer_for_seq
(
seq
).
convert_ids_to_tokens
(
if
new_char_count
and
(
last_token_id
)
not
sampling_params
.
include_stop_str_in_output
):
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
seq
.
stop_reason
=
last_token_id
return
return
# Check if the sequence has generated the EOS token.
# Check if any stop strings are matched.
if
((
not
sampling_params
.
ignore_eos
)
stop_str
=
self
.
_check_stop_strings
(
seq
,
new_char_count
,
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
sampling_params
)
if
stop_str
is
not
None
:
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
return
def
_finalize_sequence
(
self
,
seq
:
Sequence
,
# Check if the sequence has reached max_model_len.
sampling_params
:
SamplingParams
,
if
seq
.
get_len
()
>
self
.
scheduler_config
.
max_model_len
:
stop_string
:
str
)
->
None
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
if
sampling_params
.
include_stop_str_in_output
:
return
return
if
stop_string
and
seq
.
output_text
.
endswith
(
stop_string
):
# Check if the sequence has reached max_tokens.
# Truncate the output text so that the stop string is
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
# not included in the output.
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
output_text
=
seq
.
output_text
[:
-
len
(
stop_string
)]
return
@
staticmethod
def
_check_stop_strings
(
seq
:
Sequence
,
new_char_count
:
int
,
sampling_params
:
SamplingParams
)
->
Optional
[
str
]:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns the stop string if matched or else None.
"""
if
not
new_char_count
:
return
None
for
stop_str
in
sampling_params
.
stop
:
stop_string_len
=
len
(
stop_str
)
# Avoid searching already-searched text.
stop_index
=
seq
.
output_text
.
find
(
stop_str
,
-
new_char_count
-
stop_string_len
)
if
stop_index
==
-
1
:
continue
if
sampling_params
.
include_stop_str_in_output
:
# Truncate to end of stop string.
stop_index
+=
stop_string_len
if
stop_index
>=
len
(
seq
.
output_text
):
# No truncation required.
return
stop_str
# Truncate the output text to either the beginning
# or end of the stop string.
seq
.
output_text
=
seq
.
output_text
[:
stop_index
]
return
stop_str
return
None
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_executor
.
add_lora
(
lora_request
)
return
self
.
model_executor
.
add_lora
(
lora_request
)
...
...
vllm/outputs.py
View file @
e46a60aa
...
@@ -112,8 +112,10 @@ class RequestOutput:
...
@@ -112,8 +112,10 @@ class RequestOutput:
# always has the logprobs of the sampled tokens even if the
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
# logprobs are not requested.
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
outputs
=
[
outputs
=
[
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
output_text
,
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get_output_token_ids
(),
seq
.
get_output_token_ids
(),
seq
.
get_cumulative_logprob
(),
seq
.
get_cumulative_logprob
(),
seq
.
output_logprobs
if
include_logprobs
else
None
,
seq
.
output_logprobs
if
include_logprobs
else
None
,
...
...
vllm/sampling_params.py
View file @
e46a60aa
...
@@ -166,6 +166,13 @@ class SamplingParams:
...
@@ -166,6 +166,13 @@ class SamplingParams:
self
.
logits_processors
=
logits_processors
self
.
logits_processors
=
logits_processors
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
truncate_prompt_tokens
=
truncate_prompt_tokens
self
.
truncate_prompt_tokens
=
truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if
self
.
stop
and
not
include_stop_str_in_output
:
self
.
output_text_buffer_length
=
max
(
len
(
s
)
for
s
in
self
.
stop
)
-
1
else
:
self
.
output_text_buffer_length
=
0
self
.
_verify_args
()
self
.
_verify_args
()
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
self
.
_verify_beam_search
()
self
.
_verify_beam_search
()
...
@@ -226,6 +233,8 @@ class SamplingParams:
...
@@ -226,6 +233,8 @@ class SamplingParams:
and
self
.
truncate_prompt_tokens
<
1
):
and
self
.
truncate_prompt_tokens
<
1
):
raise
ValueError
(
f
"truncate_prompt_tokens must be >= 1, "
raise
ValueError
(
f
"truncate_prompt_tokens must be >= 1, "
f
"got
{
self
.
truncate_prompt_tokens
}
"
)
f
"got
{
self
.
truncate_prompt_tokens
}
"
)
if
any
(
not
stop_str
for
stop_str
in
self
.
stop
):
raise
ValueError
(
"stop cannot contain an empty string."
)
if
self
.
stop
and
not
self
.
detokenize
:
if
self
.
stop
and
not
self
.
detokenize
:
raise
ValueError
(
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"stop strings are only supported when detokenize is True. "
...
...
vllm/sequence.py
View file @
e46a60aa
...
@@ -235,6 +235,12 @@ class Sequence:
...
@@ -235,6 +235,12 @@ class Sequence:
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
def
get_output_text_to_return
(
self
,
buffer_length
:
int
):
# We return the full output text if the sequence is finished.
truncate
=
buffer_length
and
not
self
.
is_finished
()
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
self
.
output_text
)
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
# TODO This can produce incorrect hash when block size > prompt size
# TODO This can produce incorrect hash when block size > prompt size
...
...
vllm/transformers_utils/detokenizer.py
View file @
e46a60aa
...
@@ -87,12 +87,15 @@ class Detokenizer:
...
@@ -87,12 +87,15 @@ class Detokenizer:
prev_tokens
.
extend
(
next_iter_tokens
)
prev_tokens
.
extend
(
next_iter_tokens
)
def
decode_sequence_inplace
(
self
,
seq
:
Sequence
,
def
decode_sequence_inplace
(
self
,
seq
:
Sequence
,
prms
:
SamplingParams
)
->
None
:
prms
:
SamplingParams
)
->
int
:
"""Decodes the new token for a sequence. In-place operation.
"""Decodes the new token for a sequence. In-place operation.
Args:
Args:
seq: The sequence to decode.
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
"""
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
...
@@ -151,6 +154,8 @@ class Detokenizer:
...
@@ -151,6 +154,8 @@ class Detokenizer:
seq
.
read_offset
=
read_offset
seq
.
read_offset
=
read_offset
seq
.
output_text
+=
new_decoded_token_text
seq
.
output_text
+=
new_decoded_token_text
return
len
(
new_decoded_token_text
)
def
_convert_tokens_to_string_with_added_encoders
(
def
_convert_tokens_to_string_with_added_encoders
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment