Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c746226
Unverified
Commit
8c746226
authored
Oct 07, 2024
by
Brendan Wong
Committed by
GitHub
Oct 08, 2024
Browse files
[Frontend] API support for beam search for MQLLMEngine (#9117)
parent
e1faa2a5
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
215 additions
and
106 deletions
+215
-106
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+19
-24
vllm/beam_search.py
vllm/beam_search.py
+61
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+5
-7
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+107
-6
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+3
-34
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+10
-8
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+10
-8
vllm/utils.py
vllm/utils.py
+0
-19
No files found.
tests/entrypoints/openai/test_completion.py
View file @
8c746226
...
...
@@ -495,7 +495,6 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
assert
len
(
batch
.
choices
)
==
2
assert
batch
.
choices
[
0
].
text
==
batch
.
choices
[
1
].
text
try
:
# test n = 2
batch
=
await
client
.
completions
.
create
(
model
=
model_name
,
...
...
@@ -515,10 +514,6 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
2
].
text
,
"two copies of the same prompt should be the same"
assert
batch
.
choices
[
1
].
text
==
batch
.
choices
[
3
].
text
,
"two copies of the same prompt should be the same"
except
BadRequestError
as
e
:
# the only allowed exception is when beam search is not supported
# in the default mqllmengine
assert
"--disable-frontend-multiprocessing"
in
str
(
e
)
# test streaming
batch
=
await
client
.
completions
.
create
(
...
...
vllm/beam_search.py
0 → 100644
View file @
8c746226
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Stores the token ids together with the log probability of the
    sequence. ``text`` is optional and is only filled in when the sequence
    is about to be returned to the user.
    """

    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    # Cumulative log probability of the sequence.
    cum_logprob: float = 0.0
    # Decoded text; left unset until the sequence is about to be returned.
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """Result of a beam search run.

    Wraps the list of the best beam search sequences; the list has as many
    entries as the beam width.
    """

    # Best sequences produced by the search.
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Beam search state seeded from the tokens of one prompt."""

    def __init__(self, prompt_tokens: List[int]):
        # The search starts from a single beam holding just the prompt.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        # Nothing has been completed yet.
        self.completed: List[BeamSearchSequence] = []
def get_beam_search_score(
    tokens: List[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    Adapted from
    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: Token ids of the sequence being scored.
        cumulative_logprob: Cumulative log probability of the sequence.
        eos_token_id: Id of the end-of-sequence token. A trailing EOS is
            not counted towards the sequence length.
        length_penalty: Exponent applied to the length before dividing.

    Returns:
        ``cumulative_logprob / (length ** length_penalty)``, where the
        length is never taken to be smaller than 1.
    """
    seq_len = len(tokens)
    # Guard the index: an empty list has no last token to inspect.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1
    # Clamp to 1 so an empty sequence or a lone EOS cannot divide by zero.
    seq_len = max(seq_len, 1)
    return cumulative_logprob / (seq_len**length_penalty)
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a sort key that ranks beams by their beam search score.

    The returned callable takes a ``BeamSearchSequence`` and evaluates
    ``get_beam_search_score`` on its tokens and cumulative log probability,
    using the ``eos_token_id`` and ``length_penalty`` captured here.
    """

    def _beam_score(beam: BeamSearchSequence) -> float:
        score = get_beam_search_score(
            beam.tokens,
            beam.cum_logprob,
            eos_token_id,
            length_penalty,
        )
        return score

    return _beam_score
vllm/engine/async_llm_engine.py
View file @
8c746226
...
...
@@ -7,6 +7,7 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
SchedulerOutputs
...
...
@@ -14,7 +15,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.entrypoints.llm
import
BeamSearchSequence
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
...
...
@@ -33,7 +33,7 @@ from vllm.sequence import ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
collect_from_async_generator
,
deprecate_kwargs
,
get_beam_search_score
,
random_uuid
,
weak_bind
)
random_uuid
,
weak_bind
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -1052,16 +1052,14 @@ class AsyncLLMEngine:
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
def
sort_beams_key
(
x
:
BeamSearchSequence
)
->
float
:
return
get_beam_search_score
(
x
.
tokens
,
x
.
cum_logprob
,
tokenizer
.
eos_token_id
,
length_penalty
)
tokenizer
=
await
self
.
get_tokenizer
()
tokenizedPrompt
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
...
...
vllm/engine/multiprocessing/client.py
View file @
8c746226
...
...
@@ -2,8 +2,8 @@ import asyncio
import
copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Union
,
overload
)
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
List
,
Mapping
,
Optional
,
Union
,
overload
)
import
cloudpickle
import
zmq
...
...
@@ -12,6 +12,7 @@ from zmq import Frame # type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
PoolingParams
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
# yapf conflicts with isort for this block
...
...
@@ -27,14 +28,16 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCUProfileRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
PromptType
from
vllm.inputs
import
PromptType
,
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
deprecate_kwargs
from
vllm.utils
import
(
collect_from_async_generator
,
deprecate_kwargs
,
random_uuid
)
logger
=
init_logger
(
__name__
)
...
...
@@ -441,6 +444,104 @@ class MQLLMEngineClient:
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
async def beam_search(
    self,
    prompt: Union[PromptType, List[int]],
    request_id: str,
    params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
    """Run beam search on top of single-token ``generate`` requests.

    Each step issues one ``generate`` request per live beam (``max_tokens=1``,
    ``logprobs=2 * beam_width``), expands every beam with the returned
    candidate tokens, and keeps the ``beam_width`` best candidates according
    to the length-penalised score. Candidates ending in EOS are moved to the
    completed set unless ``params.ignore_eos`` is set.

    Args:
        prompt: Either a list of token ids, or a prompt that is passed to
            ``tokenizer.encode``.
        request_id: Id reported on the yielded ``RequestOutput``.
        params: Beam search settings (``beam_width``, ``max_tokens``,
            ``ignore_eos``, ``temperature``, ``length_penalty``).

    Yields:
        A single finished ``RequestOutput`` holding up to ``beam_width``
        ``CompletionOutput`` entries, best score first.
    """
    beam_width = params.beam_width
    max_tokens = params.max_tokens
    ignore_eos = params.ignore_eos
    temperature = params.temperature
    length_penalty = params.length_penalty

    tokenizer = await self.get_tokenizer(lora_request=None)
    if isinstance(prompt, list):
        tokenized_prompt = prompt
    else:
        tokenized_prompt = tokenizer.encode(prompt)
    tokenized_length = len(tokenized_prompt)

    sort_beams_key = create_sort_beams_key_function(tokenizer.eos_token_id,
                                                    length_penalty)

    # One token per step; ask for more candidates than beams so that
    # enough survive after EOS candidates are split off.
    beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                        max_tokens=1,
                                        temperature=temperature)
    all_beams = [BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)]
    completed = []

    for _ in range(max_tokens):
        if not all_beams:
            # Nothing left to expand; further steps would be no-ops.
            break

        prompts_batch = [
            TokensPrompt(prompt_token_ids=beam.tokens) for beam in all_beams
        ]

        # Use a separate id for the internal per-step requests so the
        # caller's request_id is still intact for the final output.
        step_request_id = f"beam_search-{random_uuid()}"
        tasks = [
            asyncio.create_task(
                collect_from_async_generator(
                    self.generate(individual_prompt, beam_search_params,
                                  f"{step_request_id}-{i}")))
            for i, individual_prompt in enumerate(prompts_batch)
        ]

        collected = await asyncio.gather(*tasks)
        # Keep the first collected item of each per-beam request.
        output = [x[0] for x in collected]

        # Per-step dumps are diagnostics only; keep them out of INFO logs.
        logger.debug(output)

        new_beams = []
        for current_beam, result in zip(all_beams, output):
            if result.outputs[0].logprobs is None:
                continue
            logprobs = result.outputs[0].logprobs[0]
            for token_id, logprob_obj in logprobs.items():
                new_beam = BeamSearchSequence(
                    tokens=current_beam.tokens + [token_id],
                    cum_logprob=current_beam.cum_logprob +
                    logprob_obj.logprob)

                if token_id == tokenizer.eos_token_id and not ignore_eos:
                    completed.append(new_beam)
                else:
                    new_beams.append(new_beam)

        sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
        all_beams = sorted_beams[:beam_width]

    # Beams still alive after the last step compete with the completed ones.
    completed.extend(all_beams)
    sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
    best_beams = sorted_completed[:beam_width]

    for beam in best_beams:
        # Decode only the part after the prompt.
        beam.text = tokenizer.decode(beam.tokens[tokenized_length:])

    # NOTE(review): token_ids carries the whole beam (prompt included) while
    # text is decoded from the generated suffix only -- confirm consumers
    # expect that.
    # NOTE(review): logprobs receives the scalar cumulative logprob, the same
    # value as cumulative_logprob -- confirm CompletionOutput accepts this.
    beam_search_output = RequestOutput(
        request_id=request_id,
        prompt=prompt,
        outputs=[
            CompletionOutput(
                text=beam.text,
                cumulative_logprob=beam.cum_logprob,
                token_ids=beam.tokens,
                index=i,
                logprobs=beam.cum_logprob,
            ) for (i, beam) in enumerate(best_beams)
        ],
        finished=True,
        prompt_token_ids=tokenized_prompt,
        prompt_logprobs=None)

    logger.debug(beam_search_output)

    yield beam_search_output
@
overload
# DEPRECATED
def
encode
(
self
,
...
...
vllm/entrypoints/llm.py
View file @
8c746226
import
itertools
import
warnings
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
,
overload
)
from
tqdm
import
tqdm
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
BeamSearchSequence
,
get_beam_search_score
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
...
...
@@ -28,43 +29,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Counter
,
deprecate_kwargs
,
get_beam_search_score
,
is_list_of
)
from
vllm.utils
import
Counter
,
deprecate_kwargs
,
is_list_of
logger
=
init_logger
(
__name__
)
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Stores the token ids together with the log probability of the
    sequence. ``text`` is optional and is only filled in when the sequence
    is about to be returned to the user.
    """

    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    # Cumulative log probability of the sequence.
    cum_logprob: float = 0.0
    # Decoded text; left unset until the sequence is about to be returned.
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """Result of a beam search run.

    Wraps the list of the best beam search sequences; the list has as many
    entries as the beam width.
    """

    # Best sequences produced by the search.
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Beam search state seeded from the tokens of one prompt."""

    def __init__(self, prompt_tokens: List[int]):
        # The search starts from a single beam holding just the prompt.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        # Nothing has been completed yet.
        self.completed: List[BeamSearchSequence] = []
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
8c746226
...
...
@@ -10,6 +10,7 @@ from fastapi import Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_hf_chat_template
,
...
...
@@ -236,15 +237,16 @@ class OpenAIServingChat(OpenAIServing):
log_tracing_disabled_warning
()
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
not
isinstance
(
self
.
engine_client
,
AsyncLLMEngine
):
raise
ValueError
(
"Beam search in the API server is only supported with"
" AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search."
)
assert
isinstance
(
self
.
engine_client
,
(
AsyncLLMEngine
,
MQLLMEngineClient
)),
\
"Beam search is only supported with"
\
"AsyncLLMEngine and MQLLMEngineClient."
result_generator
=
self
.
engine_client
.
beam_search
(
engine_inputs
[
'prompt_token_ids'
],
request_id
,
sampling_params
)
engine_inputs
[
'prompt_token_ids'
],
request_id
,
sampling_params
,
)
else
:
result_generator
=
self
.
engine_client
.
generate
(
engine_inputs
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
8c746226
...
...
@@ -9,6 +9,7 @@ from fastapi import Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
...
...
@@ -150,15 +151,16 @@ class OpenAIServingCompletion(OpenAIServing):
log_tracing_disabled_warning
()
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
not
isinstance
(
self
.
engine_client
,
AsyncLLMEngine
):
raise
ValueError
(
"Beam search in the API server is only supported"
" with AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search."
)
assert
isinstance
(
self
.
engine_client
,
(
AsyncLLMEngine
,
MQLLMEngineClient
)),
\
"Beam search is only supported with"
\
"AsyncLLMEngine and MQLLMEngineClient."
generator
=
self
.
engine_client
.
beam_search
(
prompt_inputs
[
"prompt_token_ids"
],
request_id_item
,
sampling_params
)
prompt_inputs
[
"prompt_token_ids"
],
request_id_item
,
sampling_params
,
)
else
:
generator
=
self
.
engine_client
.
generate
(
{
...
...
vllm/utils.py
View file @
8c746226
...
...
@@ -1370,22 +1370,3 @@ class AtomicCounter:
@property
def value(self):
    """Return the counter's current ``_value``."""
    return self._value
def get_beam_search_score(
    tokens: List[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    Adapted from
    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: Token ids of the sequence being scored.
        cumulative_logprob: Cumulative log probability of the sequence.
        eos_token_id: Id of the end-of-sequence token. A trailing EOS is
            not counted towards the sequence length.
        length_penalty: Exponent applied to the length before dividing.

    Returns:
        ``cumulative_logprob / (length ** length_penalty)``, where the
        length is never taken to be smaller than 1.
    """
    seq_len = len(tokens)
    # Guard the index: an empty list has no last token to inspect.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1
    # Clamp to 1 so an empty sequence or a lone EOS cannot divide by zero.
    seq_len = max(seq_len, 1)
    return cumulative_logprob / (seq_len**length_penalty)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment