Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0250dd68
Unverified
Commit
0250dd68
authored
Sep 23, 2024
by
youkaichao
Committed by
GitHub
Sep 23, 2024
Browse files
re-implement beam search on top of vllm core (#8726)
Co-authored-by:
Brendan Wong
<
bjwpokemon@gmail.com
>
parent
88577ac9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
171 additions
and
9 deletions
+171
-9
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+20
-4
tests/conftest.py
tests/conftest.py
+14
-0
tests/samplers/test_beam_search.py
tests/samplers/test_beam_search.py
+3
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+134
-2
No files found.
benchmarks/benchmark_throughput.py
View file @
0250dd68
...
...
@@ -90,6 +90,7 @@ def run_vllm(
download_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
EngineArgs
.
load_format
,
disable_async_output_proc
:
bool
=
False
,
use_new_beam_search_impl
:
bool
=
False
,
)
->
float
:
from
vllm
import
LLM
,
SamplingParams
llm
=
LLM
(
...
...
@@ -132,9 +133,23 @@ def run_vllm(
max_tokens
=
output_len
,
))
if
not
use_new_beam_search_impl
:
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
end
=
time
.
perf_counter
()
else
:
assert
use_beam_search
prompts
=
[
prompt
for
prompt
,
_
,
_
in
requests
]
# output_len should be the same for all requests.
output_len
=
requests
[
0
][
2
]
for
prompt
,
input_len
,
_output_len
in
requests
:
assert
_output_len
==
output_len
start
=
time
.
perf_counter
()
llm
.
beam_search
(
prompts
,
beam_width
=
n
,
max_tokens
=
output_len
,
ignore_eos
=
True
)
end
=
time
.
perf_counter
()
return
end
-
start
...
...
@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
run_args
.
append
(
args
.
disable_frontend_multiprocessing
)
elapsed_time
=
uvloop
.
run
(
run_vllm_async
(
*
run_args
))
else
:
elapsed_time
=
run_vllm
(
*
run_args
)
elapsed_time
=
run_vllm
(
*
run_args
,
args
.
use_new_beam_search_impl
)
elif
args
.
backend
==
"hf"
:
assert
args
.
tensor_parallel_size
==
1
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
...
...
@@ -396,6 +411,7 @@ if __name__ == "__main__":
default
=
1
,
help
=
"Number of generated sequences per prompt."
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--use-new-beam-search-impl"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--num-prompts"
,
type
=
int
,
default
=
1000
,
...
...
tests/conftest.py
View file @
0250dd68
...
...
@@ -798,6 +798,20 @@ class VllmRunner:
outputs
=
self
.
generate
(
prompts
,
beam_search_params
)
return
outputs
def generate_beam_search_new(
    self,
    prompts: Union[List[str], List[List[int]]],
    beam_width: int,
    max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
    """Run ``self.model.beam_search`` and flatten its results.

    Args:
        prompts: Prompts given either as strings or as lists of token IDs.
        beam_width: Forwarded positionally to ``self.model.beam_search``.
        max_tokens: Forwarded positionally to ``self.model.beam_search``.

    Returns:
        One ``(token_ids, texts)`` tuple per result returned by
        ``self.model.beam_search``, where ``token_ids`` collects the
        ``tokens`` attribute and ``texts`` the ``text`` attribute of every
        entry in that result's ``sequences``.
    """
    results = self.model.beam_search(prompts, beam_width, max_tokens)
    # Each result exposes its beams via `.sequences`; split every beam into
    # its token IDs and its decoded text.
    return [
        (
            [seq.tokens for seq in result.sequences],
            [seq.text for seq in result.sequences],
        )
        for result in results
    ]
def
encode
(
self
,
prompts
:
List
[
str
])
->
List
[
List
[
float
]]:
req_outputs
=
self
.
model
.
encode
(
prompts
)
outputs
=
[]
...
...
tests/samplers/test_beam_search.py
View file @
0250dd68
...
...
@@ -9,7 +9,7 @@ import pytest
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS
=
[
128
]
MAX_TOKENS
=
[
64
]
BEAM_WIDTHS
=
[
4
]
MODELS
=
[
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
]
...
...
@@ -33,8 +33,8 @@ def test_beam_search_single_input(
max_tokens
)
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_beam_search
(
example_prompts
,
beam_width
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_beam_search
_new
(
example_prompts
,
beam_width
,
max_tokens
)
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_texts
=
hf_outputs
[
i
]
...
...
vllm/entrypoints/llm.py
View file @
0250dd68
import
itertools
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Union
,
cast
,
overload
)
from
dataclasses
import
dataclass
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
,
overload
)
from
tqdm
import
tqdm
...
...
@@ -30,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Stores the token IDs of the sequence together with its cumulative log
    probability. ``text`` is optional and is only filled in when the
    sequence is about to be returned to the user.
    """

    # Token IDs of the sequence; the prompt tokens are included.
    tokens: List[int]
    # Cumulative log probability of the sequence.
    cum_logprob: float = 0.0
    # Decoded text; stays None until the sequence is about to be returned.
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """Result of beam search for one prompt.

    Wraps the list of the best beam search sequences. The list is expected
    to hold at most ``beam_width`` entries.
    """

    # The best sequences found for the prompt.
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Beam search state for a single prompt.

    Attributes:
        beams: The sequences that are still being extended.
        completed: The sequences that have already finished.
    """

    def __init__(self, prompt_tokens: List[int]):
        # The search starts from a single beam that holds only the prompt.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        self.completed: List[BeamSearchSequence] = []
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
...
...
@@ -354,6 +387,105 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
def beam_search(
    self,
    prompts: List[Union[str, List[int]]],
    beam_width: int,
    max_tokens: int,
    ignore_eos: bool = False,
) -> List[BeamSearchOutput]:
    """
    Generate sequences using beam search.

    Args:
        prompts: A list of prompts. Each prompt can be a string or a list
            of token IDs.
        beam_width: The number of beams to keep at each step.
        max_tokens: The max number of tokens to generate for each prompt.
        ignore_eos: If True, a candidate ending in the tokenizer's EOS
            token is kept as a live beam and extended further instead of
            being moved to the completed set.

    Returns:
        One `BeamSearchOutput` per prompt, in the order of `prompts`. Each
        holds at most `beam_width` sequences sorted by cumulative log
        probability, best first. The token lists include the prompt
        tokens, and `text` is the decoding of the full token list.

    TODO: how does beam search work together with length penalty, frequency
    penalty, and stopping criteria, etc.?
    """
    tokenizer = self.get_tokenizer()
    # generate 2 * beam_width candidates at each step
    # following the huggingface transformers implementation
    # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
    beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                        max_tokens=1,
                                        temperature=0.0)

    # String prompts are tokenized here; token-ID prompts are used as-is.
    instances: List[BeamSearchInstance] = [
        BeamSearchInstance(
            prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
        for prompt in prompts
    ]

    # Loop invariants, looked up once instead of once per candidate token.
    eos_token_id = tokenizer.eos_token_id
    stop_on_eos = not ignore_eos

    for _ in range(max_tokens):
        # Flatten the live beams of every instance into one batch.
        # chain.from_iterable is linear, unlike repeated list concatenation.
        all_beams: List[BeamSearchSequence] = list(
            itertools.chain.from_iterable(instance.beams
                                          for instance in instances))
        if not all_beams:
            # No instance has a live beam left; nothing more to expand.
            break

        prompts_batch = [
            TokensPrompt(prompt_token_ids=beam.tokens) for beam in all_beams
        ]

        # only runs for one step
        # we don't need to use tqdm here
        output = self.generate(prompts_batch,
                               sampling_params=beam_search_params,
                               use_tqdm=False)

        # `output[i]` corresponds to `all_beams[i]`; walk the flat batch
        # instance by instance using a running offset.
        offset = 0
        for instance in instances:
            num_beams = len(instance.beams)
            instance_new_beams: List[BeamSearchSequence] = []
            for i in range(offset, offset + num_beams):
                current_beam = all_beams[i]
                step_logprobs = output[i].outputs[0].logprobs
                if step_logprobs is None:
                    # if the logprobs are None, it means the sequence is
                    # completed because of the max-model-len or abortion.
                    # we don't need to add it to the new beams.
                    continue
                for token_id, logprob_obj in step_logprobs[0].items():
                    new_beam = BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)
                    if stop_on_eos and token_id == eos_token_id:
                        instance.completed.append(new_beam)
                    else:
                        instance_new_beams.append(new_beam)
            offset += num_beams

            # Keep only the `beam_width` most likely candidates alive.
            instance_new_beams.sort(key=lambda x: x.cum_logprob,
                                    reverse=True)
            instance.beams = instance_new_beams[:beam_width]

    outputs = []
    for instance in instances:
        # Beams that are still live when the loop ends compete with the
        # completed ones for the final ranking.
        instance.completed.extend(instance.beams)
        sorted_completed = sorted(instance.completed,
                                  key=lambda x: x.cum_logprob,
                                  reverse=True)
        best_beams = sorted_completed[:beam_width]

        # Text is only decoded for the sequences that are returned.
        for beam in best_beams:
            beam.text = tokenizer.decode(beam.tokens)
        outputs.append(BeamSearchOutput(sequences=best_beams))

    return outputs
def
chat
(
self
,
messages
:
List
[
ChatCompletionMessageParam
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment