Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dfeb2ecc
Unverified
Commit
dfeb2ecc
authored
Mar 25, 2024
by
Nick Hill
Committed by
GitHub
Mar 25, 2024
Browse files
[Misc] Include matched stop string/token in responses (#2976)
Co-authored-by:
Sahil Suneja
<
sahilsuneja@gmail.com
>
parent
3a243095
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
97 additions
and
7 deletions
+97
-7
tests/samplers/test_stop_reason.py
tests/samplers/test_stop_reason.py
+59
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+5
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+16
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+3
-0
vllm/outputs.py
vllm/outputs.py
+10
-4
vllm/sequence.py
vllm/sequence.py
+1
-0
No files found.
tests/samplers/test_stop_reason.py
0 → 100644
View file @
dfeb2ecc
"""Test the different finish_reason="stop" situations during generation:
1. One of the provided stop strings
2. One of the provided stop tokens
3. The EOS token
Run `pytest tests/samplers/test_stop_reason.py`.
"""
import
pytest
import
transformers
from
vllm
import
SamplingParams
MODEL
=
"facebook/opt-350m"
STOP_STR
=
"."
SEED
=
42
MAX_TOKENS
=
1024
@
pytest
.
fixture
def
vllm_model
(
vllm_runner
):
vllm_model
=
vllm_runner
(
MODEL
)
yield
vllm_model
del
vllm_model
def
test_stop_reason
(
vllm_model
,
example_prompts
):
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
MODEL
)
stop_token_id
=
tokenizer
.
convert_tokens_to_ids
(
STOP_STR
)
llm
=
vllm_model
.
model
# test stop token
outputs
=
llm
.
generate
(
example_prompts
,
sampling_params
=
SamplingParams
(
seed
=
SEED
,
max_tokens
=
MAX_TOKENS
,
stop_token_ids
=
[
stop_token_id
]))
for
output
in
outputs
:
output
=
output
.
outputs
[
0
]
assert
output
.
finish_reason
==
"stop"
assert
output
.
stop_reason
==
stop_token_id
# test stop string
outputs
=
llm
.
generate
(
example_prompts
,
sampling_params
=
SamplingParams
(
seed
=
SEED
,
max_tokens
=
MAX_TOKENS
,
stop
=
"."
))
for
output
in
outputs
:
output
=
output
.
outputs
[
0
]
assert
output
.
finish_reason
==
"stop"
assert
output
.
stop_reason
==
STOP_STR
# test EOS token
outputs
=
llm
.
generate
(
example_prompts
,
sampling_params
=
SamplingParams
(
seed
=
SEED
,
max_tokens
=
MAX_TOKENS
))
for
output
in
outputs
:
output
=
output
.
outputs
[
0
]
assert
output
.
finish_reason
==
"length"
or
(
output
.
finish_reason
==
"stop"
and
output
.
stop_reason
is
None
)
vllm/engine/llm_engine.py
View file @
dfeb2ecc
...
...
@@ -740,12 +740,15 @@ class LLMEngine:
if
seq
.
output_text
.
endswith
(
stop_str
):
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
if
seq
.
get_last_token_id
()
in
sampling_params
.
stop_token_ids
:
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
sampling_params
.
stop_token_ids
:
stop_str
=
self
.
get_tokenizer_for_seq
(
seq
).
convert_ids_to_tokens
(
seq
.
get_
last_token_id
()
)
last_token_id
)
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if the sequence has generated the EOS token.
...
...
vllm/entrypoints/openai/protocol.py
View file @
dfeb2ecc
...
...
@@ -338,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
Field
(
default
=
None
,
description
=
(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
class
CompletionResponse
(
BaseModel
):
...
...
@@ -354,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
Field
(
default
=
None
,
description
=
(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
class
CompletionStreamResponse
(
BaseModel
):
...
...
@@ -375,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
message
:
ChatMessage
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
None
class
ChatCompletionResponse
(
BaseModel
):
...
...
@@ -396,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
delta
:
DeltaMessage
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
None
class
ChatCompletionStreamResponse
(
BaseModel
):
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
dfeb2ecc
...
...
@@ -220,7 +220,8 @@ class OpenAIServingChat(OpenAIServing):
index
=
i
,
delta
=
DeltaMessage
(
content
=
delta_text
),
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
)
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
...
...
@@ -278,6 +279,7 @@ class OpenAIServingChat(OpenAIServing):
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
),
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
,
)
choices
.
append
(
choice_data
)
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
dfeb2ecc
...
...
@@ -266,6 +266,7 @@ class OpenAIServingCompletion(OpenAIServing):
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
finish_reason
=
output
.
finish_reason
stop_reason
=
output
.
stop_reason
if
output
.
finish_reason
is
not
None
:
# return final usage
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
...
...
@@ -286,6 +287,7 @@ class OpenAIServingCompletion(OpenAIServing):
text
=
delta_text
,
logprobs
=
logprobs
,
finish_reason
=
finish_reason
,
stop_reason
=
stop_reason
,
)
],
usage
=
final_usage
,
...
...
@@ -342,6 +344,7 @@ class OpenAIServingCompletion(OpenAIServing):
text
=
output_text
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
,
)
choices
.
append
(
choice_data
)
...
...
vllm/outputs.py
View file @
dfeb2ecc
import
time
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
...
...
@@ -18,6 +18,9 @@ class CompletionOutput:
logprobs: The log probabilities of the top probability words at each
position if the logprobs are requested.
finish_reason: The reason why the sequence is finished.
stop_reason: The stop string or token id that caused the completion
to stop, None if the completion finished for some other reason
including encountering the EOS token.
lora_request: The LoRA request that was used to generate the output.
"""
...
...
@@ -29,6 +32,7 @@ class CompletionOutput:
cumulative_logprob
:
float
,
logprobs
:
Optional
[
SampleLogprobs
],
finish_reason
:
Optional
[
str
]
=
None
,
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
None
:
self
.
index
=
index
...
...
@@ -37,6 +41,7 @@ class CompletionOutput:
self
.
cumulative_logprob
=
cumulative_logprob
self
.
logprobs
=
logprobs
self
.
finish_reason
=
finish_reason
self
.
stop_reason
=
stop_reason
self
.
lora_request
=
lora_request
def
finished
(
self
)
->
bool
:
...
...
@@ -48,7 +53,8 @@ class CompletionOutput:
f
"token_ids=
{
self
.
token_ids
}
, "
f
"cumulative_logprob=
{
self
.
cumulative_logprob
}
, "
f
"logprobs=
{
self
.
logprobs
}
, "
f
"finish_reason=
{
self
.
finish_reason
}
)"
)
f
"finish_reason=
{
self
.
finish_reason
}
, "
f
"stop_reason=
{
self
.
stop_reason
}
)"
)
class
RequestOutput
:
...
...
@@ -111,8 +117,8 @@ class RequestOutput:
seq
.
get_output_token_ids
(),
seq
.
get_cumulative_logprob
(),
seq
.
output_logprobs
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
)
for
seq
in
top_n_seqs
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
,
seq
.
stop_reason
)
for
seq
in
top_n_seqs
]
# Every sequence in the sequence group should have the same prompt.
...
...
vllm/sequence.py
View file @
dfeb2ecc
...
...
@@ -183,6 +183,7 @@ class Sequence:
# Initialize the logical token blocks with the prompt token ids.
self
.
_append_tokens_to_blocks
(
prompt_token_ids
)
self
.
status
=
SequenceStatus
.
WAITING
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
# Used for incremental detokenization
self
.
prefix_offset
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment