Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dfba529b
Unverified
Commit
dfba529b
authored
May 29, 2024
by
Junichi Sato
Committed by
GitHub
May 28, 2024
Browse files
[Bugfix] Remove the last EOS token unless explicitly specified (#5077)
parent
5ae5ed1e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
91 additions
and
0 deletions
+91
-0
tests/engine/output_processor/test_stop_checker.py
tests/engine/output_processor/test_stop_checker.py
+86
-0
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+5
-0
No files found.
tests/engine/output_processor/test_stop_checker.py
0 → 100644
View file @
dfba529b
from
unittest.mock
import
MagicMock
import
pytest
from
transformers
import
PreTrainedTokenizer
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Logprob
,
Sequence
,
SequenceStatus
def sequence_with_eos(text: str, eos_token: str,
                      eos_token_id: int) -> Sequence:
    """
    Build a RUNNING Sequence whose last appended token is the EOS token.

    The sequence's ``output_text`` is ``text`` followed by ``eos_token``.
    One placeholder token id is appended per character of ``text``; the
    placeholder ids start at ``eos_token_id + 1`` and increase, so none of
    them equals ``eos_token_id``. The EOS token id is appended last.
    """
    sequence = Sequence(
        seq_id=0,
        prompt="",
        prompt_token_ids=[],
        block_size=16,
        eos_token_id=eos_token_id,
    )
    sequence.output_text = text + eos_token

    # Placeholder ids for the text, followed by the real EOS id.
    first_placeholder_id = eos_token_id + 1
    placeholder_ids = range(first_placeholder_id,
                            first_placeholder_id + len(text))
    for token_id in [*placeholder_ids, eos_token_id]:
        sequence.append_token_id(token_id=token_id,
                                 logprobs={token_id: Logprob(0.0)})

    sequence.status = SequenceStatus.RUNNING
    return sequence
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
    ("This text ends with EOS token", "</s>", 2),
])
@pytest.mark.parametrize("ignore_eos", [True, False, None])
@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
@pytest.mark.skip_global_cleanup
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
                           ignore_eos: bool, include_stop_str_in_output: bool):
    """
    Check StopChecker.maybe_stop_sequence on a sequence ending with EOS.

    Expected outcomes, in order of precedence:
    - ``ignore_eos`` truthy: the sequence stays RUNNING and the output text
      still ends with the EOS token.
    - ``include_stop_str_in_output`` truthy: the sequence is
      FINISHED_STOPPED and the EOS token is kept in the output text.
    - otherwise: the sequence is FINISHED_STOPPED and the EOS token is
      stripped from the output text.
    """
    mock_tokenizer = MagicMock(spec=PreTrainedTokenizer)
    stop_checker = StopChecker(
        max_model_len=1024,
        get_tokenizer_for_seq=MagicMock(return_value=mock_tokenizer),
    )

    seq = sequence_with_eos(
        text=text_wo_eos,
        eos_token=eos_token,
        eos_token_id=eos_token_id,
    )

    # Note that `stop` and `stop_token_ids` are not specified
    sampling_params = SamplingParams(
        min_tokens=1,
        ignore_eos=ignore_eos,
        include_stop_str_in_output=include_stop_str_in_output,
    )

    # The EOS token's characters are the newly generated text.
    stop_checker.maybe_stop_sequence(
        seq=seq,
        new_char_count=len(eos_token),
        sampling_params=sampling_params,
    )

    text_with_eos = text_wo_eos + eos_token
    if ignore_eos:
        expected = (SequenceStatus.RUNNING, text_with_eos)
    elif include_stop_str_in_output:
        expected = (SequenceStatus.FINISHED_STOPPED, text_with_eos)
    else:
        expected = (SequenceStatus.FINISHED_STOPPED, text_wo_eos)

    expected_status, expected_text = expected
    assert seq.status == expected_status
    assert seq.output_text == expected_text
vllm/engine/output_processor/stop_checker.py
View file @
dfba529b
...
@@ -48,6 +48,11 @@ class StopChecker:
...
@@ -48,6 +48,11 @@ class StopChecker:
# Check if the sequence has generated the EOS token.
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment