norm / vllm · Commit 002800f0 (Unverified)

Align vLLM's beam search implementation with HF generate (#857)

Authored Sep 04, 2023 by Zhuohan Li; committed by GitHub on Sep 04, 2023.
Parent: e15932bb
Showing 4 changed files with 119 additions and 31 deletions (+119, −31).
vllm/outputs.py          +6   −4
vllm/sampling_params.py  +34  −3
vllm/sequence.py         +77  −22
vllm/worker/worker.py    +2   −2
vllm/outputs.py

```diff
@@ -75,10 +75,12 @@ class RequestOutput:
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
         assert n <= len(seqs)
-        sorted_seqs = sorted(
-            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        if seq_group.sampling_params.use_beam_search:
+            sorting_key = lambda seq: seq.get_beam_search_score(
+                seq_group.sampling_params.length_penalty)
+        else:
+            sorting_key = lambda seq: seq.get_cumulative_logprob()
+        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
         top_n_seqs = sorted_seqs[:n]
         # Create the outputs.
```
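Why this matters: cumulative log-probability decreases monotonically as a sequence grows, so ranking by it alone systematically favors short beams. Dividing by `len ** length_penalty` (the score added to `Sequence.get_beam_search_score` later in this commit) normalizes that bias away, matching HF `generate`. A minimal sketch of the difference; the `Seq` class here is a hypothetical stand-in, not vLLM's `Sequence`:

```python
# Hypothetical stand-in for illustration only; not vLLM's Sequence class.
from dataclasses import dataclass

@dataclass
class Seq:
    name: str
    cumulative_logprob: float
    length: int

    def get_cumulative_logprob(self) -> float:
        return self.cumulative_logprob

    def get_beam_search_score(self, length_penalty: float) -> float:
        # Same formula this commit adds to Sequence.get_beam_search_score.
        return self.cumulative_logprob / (self.length ** length_penalty)

short = Seq("short", cumulative_logprob=-4.0, length=4)   # avg -1.0 per token
long = Seq("long", cumulative_logprob=-6.0, length=10)    # avg -0.6 per token

# Sorting by raw cumulative logprob ranks the short beam first...
assert sorted([short, long], key=Seq.get_cumulative_logprob,
              reverse=True)[0] is short
# ...while the length-penalized score (penalty 1.0) prefers the long beam.
assert sorted([short, long], key=lambda s: s.get_beam_search_score(1.0),
              reverse=True)[0] is long
```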
vllm/sampling_params.py

```diff
@@ -34,6 +34,15 @@ class SamplingParams:
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
         use_beam_search: Whether to use beam search instead of sampling.
+        length_penalty: Float that penalizes sequences based on their length.
+            Used in beam search.
+        early_stopping: Controls the stopping condition for beam search. It
+            accepts the following values: `True`, where the generation stops as
+            soon as there are `best_of` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very
+            unlikely to find better candidates; `"never"`, where the beam search
+            procedure only stops when there cannot be better candidates
+            (canonical beam search algorithm).
         stop: List of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
         ignore_eos: Whether to ignore the EOS token and continue generating
```
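For context, a hedged sketch of how a caller might request HF-aligned beam search once these fields exist. The constraint comments reflect the `_verify_beam_search` checks visible further down in this commit, and the top-level `SamplingParams` import is assumed:

```python
from vllm import SamplingParams  # top-level re-export; assumed for this sketch

params = SamplingParams(
    n=2,                    # return the two best beams
    best_of=4,              # keep four beam candidates alive per request
    use_beam_search=True,
    temperature=0.0,        # beam search scores with logprobs, not sampling
    length_penalty=1.0,     # new in this commit; >1.0 favors longer beams
    early_stopping=False,   # new in this commit; True or "never" also allowed
    max_tokens=64,
)
```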
```diff
@@ -52,6 +61,8 @@
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
+        length_penalty: float = 1.0,
+        early_stopping: Union[bool, str] = False,
         stop: Union[None, str, List[str]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
@@ -65,6 +76,8 @@
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
         if stop is None:
             self.stop = []
         elif isinstance(stop, str):
@@ -78,9 +91,11 @@
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
-        elif self.temperature < _SAMPLING_EPS:
-            # Zero temperature means greedy sampling.
-            self._verify_greedy_sampling()
+        else:
+            self._verify_non_beam_search()
+            if self.temperature < _SAMPLING_EPS:
+                # Zero temperature means greedy sampling.
+                self._verify_greedy_sampling()

     def _verify_args(self) -> None:
         if self.n < 1:
@@ -119,6 +134,20 @@
             raise ValueError("top_p must be 1 when using beam search.")
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using beam search.")
+        if self.early_stopping not in [True, False, "never"]:
+            raise ValueError(
+                f"early_stopping must be True, False, or 'never', "
+                f"got {self.early_stopping}.")
+
+    def _verify_non_beam_search(self) -> None:
+        if self.early_stopping is not False:
+            raise ValueError("early_stopping is not effective and must be "
+                             "False when not using beam search.")
+        if (self.length_penalty < 1.0 - _SAMPLING_EPS
+                or self.length_penalty > 1.0 + _SAMPLING_EPS):
+            raise ValueError(
+                "length_penalty is not effective and must be the "
+                "default value of 1.0 when not using beam search.")

     def _verify_greedy_sampling(self) -> None:
         if self.best_of > 1:
```
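A sketch of the errors these new checks surface, assuming the pre-existing beam-search checks (`best_of` greater than 1, zero temperature, default top-p/top-k) pass first; the messages are copied from the diff:

```python
from vllm import SamplingParams  # assumed import, as above

try:
    # length_penalty is a beam-search-only knob, so plain sampling rejects it.
    SamplingParams(temperature=0.8, length_penalty=0.5)
except ValueError as e:
    print(e)  # "length_penalty is not effective and must be the default ..."

try:
    # early_stopping must be True, False, or "never" under beam search.
    SamplingParams(use_beam_search=True, best_of=4, temperature=0.0,
                   early_stopping="always")
except ValueError as e:
    print(e)  # "early_stopping must be True, False, or 'never', got always."
```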
```diff
@@ -138,6 +167,8 @@ class SamplingParams:
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
                 f"use_beam_search={self.use_beam_search}, "
+                f"length_penalty={self.length_penalty}, "
+                f"early_stopping={self.early_stopping}, "
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
```
vllm/sequence.py

```diff
@@ -69,6 +69,9 @@ class SequenceData:
     def get_len(self) -> int:
         return len(self.output_token_ids) + len(self.prompt_token_ids)

+    def get_prompt_len(self) -> int:
+        return len(self.prompt_token_ids)
+
     def get_output_len(self) -> int:
         return len(self.output_token_ids)
@@ -155,6 +158,9 @@ class Sequence:
     def get_len(self) -> int:
         return self.data.get_len()

+    def get_prompt_len(self) -> int:
+        return self.data.get_prompt_len()
+
     def get_output_len(self) -> int:
         return self.data.get_output_len()
```
```diff
@@ -170,14 +176,32 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def get_beam_search_score(self,
+                              length_penalty: float = 0.0,
+                              seq_len: Optional[int] = None,
+                              eos_token_id: Optional[int] = None) -> float:
+        """Calculate the beam search score with length penalty.
+
+        Adapted from
+        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
+        """
+        if seq_len is None:
+            seq_len = self.get_len()
+            # Note: HF implementation does not count the EOS token
+            # towards the length, we align with that here for testing.
+            if (eos_token_id is not None
+                    and self.get_last_token_id() == eos_token_id):
+                seq_len -= 1
+        return self.get_cumulative_logprob() / (seq_len**length_penalty)
+
     def is_finished(self) -> bool:
         return SequenceStatus.is_finished(self.status)

-    def fork(self, child_seq: "Sequence") -> None:
-        child_seq.logical_token_blocks = copy.deepcopy(
-            self.logical_token_blocks)
-        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
-        child_seq.data = copy.deepcopy(self.data)
+    def fork(self, new_seq_id: int) -> "Sequence":
+        new_seq = copy.deepcopy(self)
+        new_seq.seq_id = new_seq_id
+        return new_seq

     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
```
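A worked check of the new scoring rule with hypothetical numbers; note that when `eos_token_id` is supplied and the sequence ends with EOS, the length is reduced by one before normalizing, mirroring HF:

```python
# Worked example of the scoring formula above, with made-up numbers.
cumulative_logprob = -6.0
seq_len = 10          # prompt + output tokens
length_penalty = 2.0

score = cumulative_logprob / (seq_len ** length_penalty)
assert score == -0.06  # -6.0 / 100

# If the last token is EOS and eos_token_id is passed, the token is
# excluded from the length before normalizing:
score_eos = cumulative_logprob / ((seq_len - 1) ** length_penalty)
assert round(score_eos, 6) == round(-6.0 / 81, 6)
```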
```diff
@@ -203,35 +227,66 @@ class SequenceGroup:
         arrival_time: float,
     ) -> None:
         self.request_id = request_id
-        self.seqs = seqs
+        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time

+    def get_max_num_running_seqs(self) -> int:
+        """The maximum number of sequences running in parallel in the remaining
+        lifetime of the request."""
+        if self.sampling_params.use_beam_search:
+            # For beam search, maximally there will always be `best_of` beam
+            # candidates running in the future.
+            return self.sampling_params.best_of
+        else:
+            if self.sampling_params.best_of > self.num_seqs():
+                # At prompt stage, the sequence group is not yet filled up
+                # and only has one sequence running. However, in the
+                # generation stage, we will have `best_of` sequences running.
+                return self.sampling_params.best_of
+            # At sampling stages, return the number of actual sequences
+            # running.
+            return self.num_seqs(status=SequenceStatus.RUNNING)
+
     def get_seqs(
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
         if status is None:
-            return self.seqs
+            return list(self.seqs_dict.values())
         else:
-            return [seq for seq in self.seqs if seq.status == status]
+            return [
+                seq for seq in self.seqs_dict.values() if seq.status == status
+            ]
+
+    def get_finished_seqs(self) -> List[Sequence]:
+        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         return len(self.get_seqs(status))

     def find(self, seq_id: int) -> Sequence:
-        for seq in self.seqs:
-            if seq.seq_id == seq_id:
-                return seq
-        raise ValueError(f"Sequence {seq_id} not found.")
+        if seq_id not in self.seqs_dict:
+            raise ValueError(f"Sequence {seq_id} not found.")
+        return self.seqs_dict[seq_id]
+
+    def add(self, seq: Sequence) -> None:
+        if seq.seq_id in self.seqs_dict:
+            raise ValueError(f"Sequence {seq.seq_id} already exists.")
+        self.seqs_dict[seq.seq_id] = seq
+
+    def remove(self, seq_id: int) -> None:
+        if seq_id not in self.seqs_dict:
+            raise ValueError(f"Sequence {seq_id} not found.")
+        del self.seqs_dict[seq_id]

     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.get_seqs())

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs)})")
+                f"num_seqs={len(self.seqs_dict)})")


 class SequenceGroupMetadata:
```
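The list-to-dict refactor makes `find`, `add`, and `remove` O(1) lookups by sequence ID, which matters now that beam search forks and prunes sequences on every step. The two regimes of `get_max_num_running_seqs` can be summarized with a hypothetical stub (not vLLM code):

```python
# Hypothetical stub mirroring the reasoning in get_max_num_running_seqs.
def max_running_seqs(use_beam_search: bool, best_of: int,
                     num_seqs: int, num_running: int) -> int:
    if use_beam_search:
        # Beam search always keeps `best_of` candidates alive.
        return best_of
    if best_of > num_seqs:
        # Prompt stage: only one sequence exists so far, but `best_of`
        # will run once generation starts, so reserve the worst case.
        return best_of
    # Generation stage: the actual number of running sequences.
    return num_running

# Prompt stage of a best_of=4 sampling request: one sequence so far.
assert max_running_seqs(False, best_of=4, num_seqs=1, num_running=1) == 4
# Later, after two of the four sequences finish early:
assert max_running_seqs(False, best_of=4, num_seqs=4, num_running=2) == 2
# Beam search always reserves best_of slots.
assert max_running_seqs(True, best_of=4, num_seqs=4, num_running=2) == 4
```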
```diff
@@ -266,7 +321,6 @@ class SequenceOutputs:
     """The model output associated with a sequence.

     Args:
-        seq_id: The ID of the sequence.
         parent_seq_id: The ID of the parent sequence (for forking in beam
             search).
         output_token: The output token ID.
```
```diff
@@ -276,26 +330,27 @@
     def __init__(
         self,
-        seq_id: int,
         parent_seq_id: int,
         output_token: int,
         logprobs: Dict[int, float],
     ) -> None:
-        self.seq_id = seq_id
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f"SequenceOutputs(seq_id={self.seq_id}, "
-                f"parent_seq_id={self.parent_seq_id}, "
+        return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}), "
                 f"logprobs={self.logprobs}")

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
-            return NotImplemented
-        return (self.seq_id == other.seq_id
-                and self.parent_seq_id == other.parent_seq_id
+            return NotImplementedError()
+        return (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token
                 and self.logprobs == other.logprobs)
+
+
+# For each sequence group, we generate a list of SequenceOutputs objects,
+# each of which contains one possible candidate for the next token.
+SamplerOutput = List[List[SequenceOutputs]]
```
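`SamplerOutput` is nested per sequence group, then per candidate: the outer list has one entry per scheduled group, and each inner list holds one `SequenceOutputs` per next-token candidate, with `parent_seq_id` naming the live beam it extends (so one parent can spawn several children, which a dict keyed by sequence ID could not express). A hedged sketch of walking the structure, with hypothetical values:

```python
# Hypothetical walk over a SamplerOutput; field names match the diff above.
from typing import Dict, List

class SequenceOutputs:  # simplified stand-in for the class defined above
    def __init__(self, parent_seq_id: int, output_token: int,
                 logprobs: Dict[int, float]) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

SamplerOutput = List[List[SequenceOutputs]]

sampler_output: SamplerOutput = [
    # One sequence group with two beam candidates, both forked from seq 7.
    [SequenceOutputs(7, 42, {42: -0.1}), SequenceOutputs(7, 99, {99: -1.3})],
]

for group_idx, candidates in enumerate(sampler_output):
    for cand in candidates:
        # Beam search may emit several children of the same parent; the
        # engine forks the parent sequence once per extra child.
        print(group_idx, cand.parent_seq_id, cand.output_token)
```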
vllm/worker/worker.py

```diff
@@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
+from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.utils import get_gpu_memory
@@ -260,7 +260,7 @@ class Worker:
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         # Issue cache operations.
         issued_cache_op = False
         if blocks_to_swap_in:
```