norm/vllm: commit c06170cc (unverified)

Add a flag to include stop string in output text (#1976)

Authored Dec 15, 2023 by Yunfeng Bai; committed by GitHub on Dec 15, 2023.
Parent: 614856da
Showing 2 changed files with 32 additions and 24 deletions (+32 / -24):

  vllm/engine/llm_engine.py   +4 / -3
  vllm/sampling_params.py     +28 / -21
vllm/engine/llm_engine.py
```diff
@@ -682,9 +682,10 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
```
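The new branch only guards the truncation step: matching a stop string still finishes the sequence either way, and the separate stop-token-ID check below is unaffected by the flag. A standalone sketch of the added logic (plain Python with a hypothetical `finish_text` helper, not vLLM code) to illustrate the behavioral difference:

```python
# Hypothetical helper mirroring the branch added above: the matched stop
# string is trimmed only when include_stop_str_in_output is False.
def finish_text(output_text: str, stop_str: str,
                include_stop_str_in_output: bool) -> str:
    if output_text.endswith(stop_str):
        if not include_stop_str_in_output:
            # Truncate so the stop string is not included in the output.
            output_text = output_text[:-len(stop_str)]
    return output_text

text = "The answer is 42.###"
print(finish_text(text, "###", False))  # The answer is 42.
print(finish_text(text, "###", True))   # The answer is 42.###
```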
vllm/sampling_params.py
```diff
@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union

 import torch

 _SAMPLING_EPS = 1e-5
```
```diff
@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in
+            output text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
```
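Constructing `SamplingParams` with the new keyword might look like the following; a minimal sketch, assuming vLLM is installed and exports `SamplingParams` from the package root as it does in this era of the codebase:

```python
from vllm import SamplingParams

# Default behavior: the matched stop string is trimmed from the output text.
default_params = SamplingParams(stop=["\n\n"])
print(default_params.include_stop_str_in_output)  # False

# Opt in: keep the matched stop string at the end of the output text.
keep_stop = SamplingParams(stop=["\n\n"], include_stop_str_in_output=True)
print(keep_stop)  # __repr__ now reports include_stop_str_in_output=True
```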
```diff
@@ -103,6 +106,7 @@ class SamplingParams:
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
```
```diff
@@ -140,6 +144,7 @@ class SamplingParams:
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
```
```diff
@@ -227,24 +232,26 @@ class SamplingParams:
         return SamplingType.RANDOM

     def __repr__(self) -> str:
         return (
             f"SamplingParams(n={self.n}, "
             f"best_of={self.best_of}, "
             f"presence_penalty={self.presence_penalty}, "
             f"frequency_penalty={self.frequency_penalty}, "
             f"repetition_penalty={self.repetition_penalty}, "
             f"temperature={self.temperature}, "
             f"top_p={self.top_p}, "
             f"top_k={self.top_k}, "
             f"min_p={self.min_p}, "
             f"use_beam_search={self.use_beam_search}, "
             f"length_penalty={self.length_penalty}, "
             f"early_stopping={self.early_stopping}, "
             f"stop={self.stop}, "
             f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "
             f"logprobs={self.logprobs}, "
             f"prompt_logprobs={self.prompt_logprobs}, "
             f"skip_special_tokens={self.skip_special_tokens}, "
             "spaces_between_special_tokens="
             f"{self.spaces_between_special_tokens})")
```
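End to end, the flag is simply passed through the offline-inference entry point. A hedged sketch, assuming vLLM's `LLM` / `generate` API of this period; the model name is only a small placeholder, not something this commit prescribes:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=64, stop=["\n"],
                        include_stop_str_in_output=True)
outputs = llm.generate(["Hello, my name is"], params)

# With the flag set, the matched stop string ("\n") stays at the end of the
# generated text instead of being trimmed by the engine's stop check.
print(repr(outputs[0].outputs[0].text))
```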