vllm commit c06170cc (unverified)

Add a flag to include stop string in output text (#1976)

Authored by Yunfeng Bai on Dec 15, 2023; committed by GitHub on Dec 15, 2023.
Parent: 614856da
Showing 2 changed files with 32 additions and 24 deletions (+32, -24).
Files changed:
- vllm/engine/llm_engine.py (+4, -3)
- vllm/sampling_params.py (+28, -21)
vllm/engine/llm_engine.py

@@ -682,9 +682,10 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
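For illustration, here is a minimal standalone sketch (not vLLM code) of the behavior this hunk introduces: a matched stop string is trimmed from the output unless the new flag is set. The helper name apply_stop_string is made up for this example.

    def apply_stop_string(output_text: str, stop_str: str,
                          include_stop_str_in_output: bool) -> str:
        # Mirror the new LLMEngine logic: only truncate when the flag is off.
        if output_text.endswith(stop_str) and not include_stop_str_in_output:
            return output_text[:-len(stop_str)]
        return output_text

    assert apply_stop_string("Hello world!###", "###", False) == "Hello world!"
    assert apply_stop_string("Hello world!###", "###", True) == "Hello world!###"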
vllm/sampling_params.py

@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union

 import torch

 _SAMPLING_EPS = 1e-5
@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in output
+            text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -103,6 +106,7 @@ class SamplingParams:
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -140,6 +144,7 @@ class SamplingParams:
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -227,24 +232,26 @@ class SamplingParams:
         return SamplingType.RANDOM

     def __repr__(self) -> str:
-        return (f"SamplingParams(n={self.n}, "
-                f"best_of={self.best_of}, "
-                f"presence_penalty={self.presence_penalty}, "
-                f"frequency_penalty={self.frequency_penalty}, "
-                f"repetition_penalty={self.repetition_penalty}, "
-                f"temperature={self.temperature}, "
-                f"top_p={self.top_p}, "
-                f"top_k={self.top_k}, "
-                f"min_p={self.min_p}, "
-                f"use_beam_search={self.use_beam_search}, "
-                f"length_penalty={self.length_penalty}, "
-                f"early_stopping={self.early_stopping}, "
-                f"stop={self.stop}, "
-                f"stop_token_ids={self.stop_token_ids}, "
-                f"ignore_eos={self.ignore_eos}, "
-                f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}, "
-                f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens}, "
-                "spaces_between_special_tokens="
-                f"{self.spaces_between_special_tokens})")
+        return (
+            f"SamplingParams(n={self.n}, "
+            f"best_of={self.best_of}, "
+            f"presence_penalty={self.presence_penalty}, "
+            f"frequency_penalty={self.frequency_penalty}, "
+            f"repetition_penalty={self.repetition_penalty}, "
+            f"temperature={self.temperature}, "
+            f"top_p={self.top_p}, "
+            f"top_k={self.top_k}, "
+            f"min_p={self.min_p}, "
+            f"use_beam_search={self.use_beam_search}, "
+            f"length_penalty={self.length_penalty}, "
+            f"early_stopping={self.early_stopping}, "
+            f"stop={self.stop}, "
+            f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+            f"ignore_eos={self.ignore_eos}, "
+            f"max_tokens={self.max_tokens}, "
+            f"logprobs={self.logprobs}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens}, "
+            "spaces_between_special_tokens="
+            f"{self.spaces_between_special_tokens})")
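For reference, a minimal usage sketch of the new flag through vLLM's offline LLM API; the model name, prompt, and stop string below are placeholders, not part of this commit.

    from vllm import LLM, SamplingParams

    # include_stop_str_in_output defaults to False, which keeps the old
    # behavior of trimming the stop string from the returned text.
    params = SamplingParams(
        max_tokens=64,
        stop=["###"],
        include_stop_str_in_output=True,  # keep the stop string in the output
    )

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    outputs = llm.generate(["Write one sentence about GPUs."], params)
    print(outputs[0].outputs[0].text)  # ends with "###" if the stop string was hit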