xdb4_94051 / vllm

Commit 002800f0 (unverified), authored Sep 04, 2023 by Zhuohan Li, committed via GitHub on Sep 04, 2023.
Align vLLM's beam search implementation with HF generate (#857)
Parent: e15932bb

Changes: 24. Showing 4 changed files with 119 additions and 31 deletions (+119, −31).
Files changed on this page:

  vllm/outputs.py          +6  −4
  vllm/sampling_params.py  +34 −3
  vllm/sequence.py         +77 −22
  vllm/worker/worker.py    +2  −2
vllm/outputs.py
@@ -75,10 +75,12 @@ class RequestOutput:
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
         assert n <= len(seqs)
-        sorted_seqs = sorted(
-            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        if seq_group.sampling_params.use_beam_search:
+            sorting_key = lambda seq: seq.get_beam_search_score(
+                seq_group.sampling_params.length_penalty)
+        else:
+            sorting_key = lambda seq: seq.get_cumulative_logprob()
+        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
         top_n_seqs = sorted_seqs[:n]

         # Create the outputs.
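This hunk switches the ranking key for finished candidates: under beam search they are now ordered by the length-normalized beam search score rather than the raw cumulative logprob. A standalone sketch (illustrative numbers, not vLLM code) of how the two keys can disagree:

# Standalone sketch: how the two ranking keys can disagree.
# Each tuple is (cumulative_logprob, num_tokens) for a finished candidate.
candidates = [(-6.0, 4), (-7.5, 10)]

by_logprob = sorted(candidates, key=lambda c: c[0], reverse=True)
# Beam search score with length_penalty=1.0 is logprob / length.
by_beam_score = sorted(candidates, key=lambda c: c[0] / c[1], reverse=True)

print(by_logprob[0])     # (-6.0, 4): raw logprob picks the short candidate
print(by_beam_score[0])  # (-7.5, 10): the normalized score picks the long one

Without the normalization, raw cumulative logprob systematically favors shorter candidates, since every extra token adds a negative term.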
vllm/sampling_params.py
@@ -34,6 +34,15 @@ class SamplingParams:
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
         use_beam_search: Whether to use beam search instead of sampling.
+        length_penalty: Float that penalizes sequences based on their length.
+            Used in beam search.
+        early_stopping: Controls the stopping condition for beam search. It
+            accepts the following values: `True`, where the generation stops as
+            soon as there are `best_of` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very
+            unlikely to find better candidates; `"never"`, where the beam search
+            procedure only stops when there cannot be better candidates
+            (canonical beam search algorithm).
         stop: List of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
         ignore_eos: Whether to ignore the EOS token and continue generating
@@ -52,6 +61,8 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
+        length_penalty: float = 1.0,
+        early_stopping: Union[bool, str] = False,
         stop: Union[None, str, List[str]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
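For illustration, a request using the two new arguments might look like the sketch below (hypothetical usage, not part of this diff; it leaves `top_p` and `top_k` at the defaults that `_verify_beam_search` requires):

from vllm import SamplingParams

# Hypothetical beam search request using the new arguments.
params = SamplingParams(
    n=2,                     # return the two best finished beams
    best_of=4,               # keep four beam candidates alive per request
    temperature=0.0,
    use_beam_search=True,
    length_penalty=1.0,      # HF-style length normalization exponent
    early_stopping="never",  # canonical beam search stopping condition
)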
@@ -65,6 +76,8 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
         if stop is None:
             self.stop = []
         elif isinstance(stop, str):
@@ -78,9 +91,11 @@ class SamplingParams:
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
-        elif self.temperature < _SAMPLING_EPS:
-            # Zero temperature means greedy sampling.
-            self._verify_greedy_sampling()
+        else:
+            self._verify_non_beam_search()
+            if self.temperature < _SAMPLING_EPS:
+                # Zero temperature means greedy sampling.
+                self._verify_greedy_sampling()

     def _verify_args(self) -> None:
         if self.n < 1:
@@ -119,6 +134,20 @@ class SamplingParams:
             raise ValueError("top_p must be 1 when using beam search.")
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using beam search.")
+        if self.early_stopping not in [True, False, "never"]:
+            raise ValueError(
+                f"early_stopping must be True, False, or 'never', "
+                f"got {self.early_stopping}.")
+
+    def _verify_non_beam_search(self) -> None:
+        if self.early_stopping is not False:
+            raise ValueError("early_stopping is not effective and must be "
+                             "False when not using beam search.")
+        if (self.length_penalty < 1.0 - _SAMPLING_EPS
+                or self.length_penalty > 1.0 + _SAMPLING_EPS):
+            raise ValueError(
+                "length_penalty is not effective and must be the "
+                "default value of 1.0 when not using beam search.")

     def _verify_greedy_sampling(self) -> None:
         if self.best_of > 1:
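The two validators reject combinations that would otherwise silently do nothing. A sketch of the failure modes (hypothetical calls against the constructor above):

from vllm import SamplingParams

# length_penalty has no effect without beam search, so it must stay at 1.0.
try:
    SamplingParams(length_penalty=0.5)  # use_beam_search defaults to False
except ValueError as exc:
    print(exc)  # "length_penalty is not effective and must be the ..."

# early_stopping only accepts True, False, or the string "never".
try:
    SamplingParams(best_of=4, temperature=0.0, use_beam_search=True,
                   early_stopping="always")
except ValueError as exc:
    print(exc)  # "early_stopping must be True, False, or 'never', got always."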
@@ -138,6 +167,8 @@ class SamplingParams:
f
"top_p=
{
self
.
top_p
}
, "
f
"top_k=
{
self
.
top_k
}
, "
f
"use_beam_search=
{
self
.
use_beam_search
}
, "
f
"length_penalty=
{
self
.
length_penalty
}
, "
f
"early_stopping=
{
self
.
early_stopping
}
, "
f
"stop=
{
self
.
stop
}
, "
f
"ignore_eos=
{
self
.
ignore_eos
}
, "
f
"max_tokens=
{
self
.
max_tokens
}
, "
...
...
vllm/sequence.py
@@ -69,6 +69,9 @@ class SequenceData:
     def get_len(self) -> int:
         return len(self.output_token_ids) + len(self.prompt_token_ids)

+    def get_prompt_len(self) -> int:
+        return len(self.prompt_token_ids)
+
     def get_output_len(self) -> int:
         return len(self.output_token_ids)
@@ -155,6 +158,9 @@ class Sequence:
     def get_len(self) -> int:
         return self.data.get_len()

+    def get_prompt_len(self) -> int:
+        return self.data.get_prompt_len()
+
     def get_output_len(self) -> int:
         return self.data.get_output_len()
@@ -170,14 +176,32 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def get_beam_search_score(self,
+                              length_penalty: float = 0.0,
+                              seq_len: Optional[int] = None,
+                              eos_token_id: Optional[int] = None) -> float:
+        """Calculate the beam search score with length penalty.
+
+        Adapted from
+        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
+        """
+        if seq_len is None:
+            seq_len = self.get_len()
+            # Note: HF implementation does not count the EOS token
+            # towards the length, we align with that here for testing.
+            if (eos_token_id is not None
+                    and self.get_last_token_id() == eos_token_id):
+                seq_len -= 1
+        return self.get_cumulative_logprob() / (seq_len**length_penalty)
+
     def is_finished(self) -> bool:
         return SequenceStatus.is_finished(self.status)

-    def fork(self, child_seq: "Sequence") -> None:
-        child_seq.logical_token_blocks = copy.deepcopy(
-            self.logical_token_blocks)
-        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
-        child_seq.data = copy.deepcopy(self.data)
+    def fork(self, new_seq_id: int) -> "Sequence":
+        new_seq = copy.deepcopy(self)
+        new_seq.seq_id = new_seq_id
+        return new_seq

     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
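`get_beam_search_score` reproduces HF's length-normalized score, `cumulative_logprob / seq_len ** length_penalty`, excluding a trailing EOS token from the length just as HF does. A worked numeric example (standalone arithmetic, not vLLM code):

EOS_TOKEN_ID = 2
cumulative_logprob = -8.0
token_ids = [5, 9, 11, 7, 2]  # five tokens, the last one being EOS

seq_len = len(token_ids)
if token_ids[-1] == EOS_TOKEN_ID:
    seq_len -= 1  # match HF: EOS does not count towards the length

length_penalty = 1.0
score = cumulative_logprob / (seq_len ** length_penalty)
print(score)  # -8.0 / 4 = -2.0 (it would be -1.6 if EOS were counted)

With `length_penalty > 1.0` the denominator grows faster with length, so longer candidates are penalized less severely relative to short ones and tend to win; values below 1.0 favor shorter candidates.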
@@ -203,35 +227,66 @@ class SequenceGroup:
         arrival_time: float,
     ) -> None:
         self.request_id = request_id
-        self.seqs = seqs
+        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time

+    def get_max_num_running_seqs(self) -> int:
+        """The maximum number of sequences running in parallel in the remaining
+        lifetime of the request."""
+        if self.sampling_params.use_beam_search:
+            # For beam search, maximally there will always be `best_of` beam
+            # candidates running in the future.
+            return self.sampling_params.best_of
+        else:
+            if self.sampling_params.best_of > self.num_seqs():
+                # At prompt stage, the sequence group is not yet filled up
+                # and only have one sequence running. However, in the
+                # generation stage, we will have `best_of` sequences running.
+                return self.sampling_params.best_of
+            # At sampling stages, return the number of actual sequences
+            # running.
+            return self.num_seqs(status=SequenceStatus.RUNNING)
+
     def get_seqs(
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
         if status is None:
-            return self.seqs
+            return list(self.seqs_dict.values())
         else:
-            return [seq for seq in self.seqs if seq.status == status]
+            return [
+                seq for seq in self.seqs_dict.values()
+                if seq.status == status
+            ]
+
+    def get_finished_seqs(self) -> List[Sequence]:
+        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         return len(self.get_seqs(status))

     def find(self, seq_id: int) -> Sequence:
-        for seq in self.seqs:
-            if seq.seq_id == seq_id:
-                return seq
-        raise ValueError(f"Sequence {seq_id} not found.")
+        if seq_id not in self.seqs_dict:
+            raise ValueError(f"Sequence {seq_id} not found.")
+        return self.seqs_dict[seq_id]
+
+    def add(self, seq: Sequence) -> None:
+        if seq.seq_id in self.seqs_dict:
+            raise ValueError(f"Sequence {seq.seq_id} already exists.")
+        self.seqs_dict[seq.seq_id] = seq
+
+    def remove(self, seq_id: int) -> None:
+        if seq_id not in self.seqs_dict:
+            raise ValueError(f"Sequence {seq_id} not found.")
+        del self.seqs_dict[seq_id]

     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.get_seqs())

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs)})")
+                f"num_seqs={len(self.seqs_dict)})")


 class SequenceGroupMetadata:
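Replacing the sequence list with a dict keyed by `seq_id` turns the `find`/`add`/`remove` churn of beam search forking into constant-time dictionary operations instead of list scans. A minimal sketch of the access pattern (hypothetical ids, not vLLM code):

# Beam search forks and prunes sequences every step; a dict gives O(1)
# lookup, insertion, and deletion by sequence id.
seqs_dict = {0: "seq-0", 1: "seq-1"}

parent = seqs_dict[0]            # find: direct lookup
seqs_dict[2] = f"{parent}-fork"  # add: forked beam candidate under a new id
del seqs_dict[1]                 # remove: pruned beam candidate

assert sorted(seqs_dict) == [0, 2]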
@@ -266,7 +321,6 @@ class SequenceOutputs:
     """The model output associated with a sequence.

     Args:
-        seq_id: The ID of the sequence.
         parent_seq_id: The ID of the parent sequence (for forking in beam
             search).
         output_token: The output token ID.
@@ -276,26 +330,27 @@ class SequenceOutputs:
     def __init__(
         self,
-        seq_id: int,
         parent_seq_id: int,
         output_token: int,
         logprobs: Dict[int, float],
     ) -> None:
-        self.seq_id = seq_id
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f"SequenceOutputs(seq_id={self.seq_id}, "
-                f"parent_seq_id={self.parent_seq_id}, "
+        return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}), "
                 f"logprobs={self.logprobs}")

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
-            return NotImplemented
-        return (self.seq_id == other.seq_id
-                and self.parent_seq_id == other.parent_seq_id
+            return NotImplementedError()
+        return (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token
                 and self.logprobs == other.logprobs)
+
+
+# For each sequence group, we generate a list of SequenceOutputs object,
+# each of which contains one possible candidate for the next token.
+SamplerOutput = List[List[SequenceOutputs]]
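The new `SamplerOutput` alias spells out the sampler's return shape: one inner list per sequence group, with one `SequenceOutputs` per candidate next token. An illustrative value (made-up token ids and logprobs, not part of this diff):

from vllm.sequence import SamplerOutput, SequenceOutputs

# One sequence group running beam search with two surviving candidates,
# both forked from parent sequence 0.
output: SamplerOutput = [
    [
        SequenceOutputs(parent_seq_id=0, output_token=42, logprobs={42: -0.7}),
        SequenceOutputs(parent_seq_id=0, output_token=13, logprobs={13: -1.1}),
    ],
]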
vllm/worker/worker.py
@@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
+from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.utils import get_gpu_memory
@@ -260,7 +260,7 @@ class Worker:
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
Dict
[
int
,
Sequence
Output
s
]
:
)
->
Sampler
Output
:
# Issue cache operations.
issued_cache_op
=
False
if
blocks_to_swap_in
:
...
...