Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4d31cd42
Unverified
Commit
4d31cd42
authored
Oct 14, 2024
by
Brendan Wong
Committed by
GitHub
Oct 14, 2024
Browse files
[Frontend] merge beam search implementations (#9296)
parent
473e7b36
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
145 additions
and
234 deletions
+145
-234
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+6
-104
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+17
-109
vllm/engine/protocol.py
vllm/engine/protocol.py
+122
-7
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+0
-7
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+0
-7
No files found.
vllm/engine/async_llm_engine.py
View file @
4d31cd42
...
@@ -7,7 +7,6 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
...
@@ -7,7 +7,6 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
from
weakref
import
ReferenceType
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
...
@@ -15,25 +14,24 @@ from vllm.engine.arg_utils import AsyncEngineArgs
...
@@ -15,25 +14,24 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
,
SchedulerOutputState
from
vllm.engine.llm_engine
import
LLMEngine
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.protocol
import
EngineClient
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
PromptType
,
TokensPrompt
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
get_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingRequestOutput
,
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
collect_from_async_generator
,
deprecate_kwargs
,
from
vllm.utils
import
deprecate_kwargs
,
weak_bind
random_uuid
,
weak_bind
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
@@ -541,7 +539,7 @@ async def build_guided_decoding_logits_processor_async(
...
@@ -541,7 +539,7 @@ async def build_guided_decoding_logits_processor_async(
return
sampling_params
return
sampling_params
class
AsyncLLMEngine
:
class
AsyncLLMEngine
(
EngineClient
)
:
"""An asynchronous wrapper for :class:`LLMEngine`.
"""An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to make it
This class is used to wrap the :class:`LLMEngine` class to make it
...
@@ -1039,102 +1037,6 @@ class AsyncLLMEngine:
...
@@ -1039,102 +1037,6 @@ class AsyncLLMEngine:
):
):
yield
LLMEngine
.
validate_output
(
output
,
RequestOutput
)
yield
LLMEngine
.
validate_output
(
output
,
RequestOutput
)
async
def
beam_search
(
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
request_id
:
str
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
tokenizer
=
await
self
.
get_tokenizer
()
tokenizedPrompt
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
]
tasks
=
[]
request_id
=
f
"beam_search-
{
random_uuid
()
}
"
for
i
,
individual_prompt
in
enumerate
(
prompts_batch
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
)))
tasks
.
append
(
task
)
output
=
await
asyncio
.
gather
(
*
tasks
)
output
=
[
x
[
0
]
for
x
in
output
]
logger
.
info
(
output
)
new_beams
=
[]
for
i
,
current_beam
in
enumerate
(
all_beams
):
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
completed
.
append
(
new_beam
)
else
:
new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
completed
.
extend
(
all_beams
)
sorted_completed
=
sorted
(
completed
,
key
=
sort_beams_key
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenizedLength
:])
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
index
=
i
,
logprobs
=
beam
.
cum_logprob
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
tokenizedPrompt
,
prompt_logprobs
=
None
)
yield
LLMEngine
.
validate_output
(
beam_search_output
,
RequestOutput
)
async
def
encode
(
async
def
encode
(
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
...
...
vllm/engine/multiprocessing/client.py
View file @
4d31cd42
...
@@ -12,8 +12,8 @@ from zmq import Frame # type: ignore[attr-defined]
...
@@ -12,8 +12,8 @@ from zmq import Frame # type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
zmq.asyncio
import
Socket
from
vllm
import
PoolingParams
from
vllm
import
PoolingParams
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -26,18 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -26,18 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCError
,
RPCProcessRequest
,
RPCError
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
RPCUProfileRequest
)
from
vllm.engine.protocol
import
EngineClient
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
PromptType
,
TokensPrompt
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.
outputs
import
(
Co
mple
tion
Output
,
EmbeddingRequestOutput
,
from
vllm.
model_executor.layers.sampler
import
Sa
mple
r
Output
RequestOutput
)
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
(
collect_from_async_generator
,
deprecate_kwargs
,
from
vllm.utils
import
deprecate_kwargs
random_uuid
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -53,7 +53,7 @@ class MQClientClosedError(Exception):
...
@@ -53,7 +53,7 @@ class MQClientClosedError(Exception):
"""
"""
class
MQLLMEngineClient
:
class
MQLLMEngineClient
(
EngineClient
)
:
"""A client wrapper for MQLLMEngine that conforms to the
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
EngineClient protocol.
...
@@ -316,7 +316,7 @@ class MQLLMEngineClient:
...
@@ -316,7 +316,7 @@ class MQLLMEngineClient:
or
response
!=
VLLM_RPC_SUCCESS_STR
):
or
response
!=
VLLM_RPC_SUCCESS_STR
):
raise
ValueError
(
error_message
)
raise
ValueError
(
error_message
)
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
...
@@ -344,8 +344,14 @@ class MQLLMEngineClient:
...
@@ -344,8 +344,14 @@ class MQLLMEngineClient:
await
self
.
_send_one_way_rpc_request
(
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
socket
=
self
.
input_socket
)
request
=
RPCAbortRequest
(
request_id
),
socket
=
self
.
input_socket
)
async
def
do_log_stats
(
self
):
async
def
do_log_stats
(
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
,
)
->
None
:
"""
Ignore do_log_stats (handled on MQLLMEngine polling)
"""
pass
pass
async
def
check_health
(
self
):
async
def
check_health
(
self
):
...
@@ -444,104 +450,6 @@ class MQLLMEngineClient:
...
@@ -444,104 +450,6 @@ class MQLLMEngineClient:
lora_request
,
trace_headers
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
prompt_adapter_request
,
priority
)
async
def
beam_search
(
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
request_id
:
str
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizedPrompt
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
]
tasks
=
[]
request_id
=
f
"beam_search-
{
random_uuid
()
}
"
for
i
,
individual_prompt
in
enumerate
(
prompts_batch
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
)))
tasks
.
append
(
task
)
output
=
await
asyncio
.
gather
(
*
tasks
)
output
=
[
x
[
0
]
for
x
in
output
]
logger
.
info
(
output
)
new_beams
=
[]
for
i
,
current_beam
in
enumerate
(
all_beams
):
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
completed
.
append
(
new_beam
)
else
:
new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
completed
.
extend
(
all_beams
)
sorted_completed
=
sorted
(
completed
,
key
=
sort_beams_key
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenizedLength
:])
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
index
=
i
,
logprobs
=
beam
.
cum_logprob
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
tokenizedPrompt
,
prompt_logprobs
=
None
)
logger
.
info
(
beam_search_output
)
yield
beam_search_output
@
overload
# DEPRECATED
@
overload
# DEPRECATED
def
encode
(
def
encode
(
self
,
self
,
...
...
vllm/engine/protocol.py
View file @
4d31cd42
from
typing
import
(
AsyncGenerator
,
List
,
Mapping
,
Optional
,
Protocol
,
import
asyncio
runtime_checkable
)
from
abc
import
ABC
,
abstractmethod
from
typing
import
AsyncGenerator
,
List
,
Mapping
,
Optional
,
Union
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
collect_from_async_generator
,
random_uuid
logger
=
init_logger
(
__name__
)
@
runtime_checkable
class
EngineClient
(
Protocol
):
class
EngineClient
(
ABC
):
"""Protocol class for Clients to Engine"""
"""Protocol class for Clients to Engine"""
@
property
@
property
@
abstractmethod
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
...
...
@
property
@
property
@
abstractmethod
def
is_stopped
(
self
)
->
bool
:
def
is_stopped
(
self
)
->
bool
:
...
...
@
property
@
property
@
abstractmethod
def
errored
(
self
)
->
bool
:
def
errored
(
self
)
->
bool
:
...
...
@
property
@
property
@
abstractmethod
def
dead_error
(
self
)
->
BaseException
:
def
dead_error
(
self
)
->
BaseException
:
...
...
@
abstractmethod
def
generate
(
def
generate
(
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
...
@@ -46,6 +57,101 @@ class EngineClient(Protocol):
...
@@ -46,6 +57,101 @@ class EngineClient(Protocol):
"""Generate outputs for a request."""
"""Generate outputs for a request."""
...
...
async
def
beam_search
(
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
request_id
:
str
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizedPrompt
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
]
tasks
=
[]
request_id
=
f
"beam_search-
{
random_uuid
()
}
"
for
i
,
individual_prompt
in
enumerate
(
prompts_batch
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
)))
tasks
.
append
(
task
)
output
=
await
asyncio
.
gather
(
*
tasks
)
output
=
[
x
[
0
]
for
x
in
output
]
new_beams
=
[]
for
i
,
current_beam
in
enumerate
(
all_beams
):
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
completed
.
append
(
new_beam
)
else
:
new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
completed
.
extend
(
all_beams
)
sorted_completed
=
sorted
(
completed
,
key
=
sort_beams_key
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenizedLength
:])
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
index
=
i
,
logprobs
=
beam
.
cum_logprob
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
tokenizedPrompt
,
prompt_logprobs
=
None
)
yield
beam_search_output
@
abstractmethod
def
encode
(
def
encode
(
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
...
@@ -58,6 +164,7 @@ class EngineClient(Protocol):
...
@@ -58,6 +164,7 @@ class EngineClient(Protocol):
"""Generate outputs for a request from an embedding model."""
"""Generate outputs for a request from an embedding model."""
...
...
@
abstractmethod
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
"""Abort a request.
...
@@ -65,14 +172,17 @@ class EngineClient(Protocol):
...
@@ -65,14 +172,17 @@ class EngineClient(Protocol):
request_id: The unique id of the request.
request_id: The unique id of the request.
"""
"""
@
abstractmethod
async
def
get_model_config
(
self
)
->
ModelConfig
:
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
"""Get the model configuration of the vLLM engine."""
...
...
@
abstractmethod
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
...
...
"""Get the decoding configuration of the vLLM engine."""
"""Get the decoding configuration of the vLLM engine."""
@
abstractmethod
async
def
get_tokenizer
(
async
def
get_tokenizer
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -80,9 +190,11 @@ class EngineClient(Protocol):
...
@@ -80,9 +190,11 @@ class EngineClient(Protocol):
"""Get the appropriate tokenizer for the request"""
"""Get the appropriate tokenizer for the request"""
...
...
@
abstractmethod
async
def
is_tracing_enabled
(
self
)
->
bool
:
async
def
is_tracing_enabled
(
self
)
->
bool
:
...
...
@
abstractmethod
async
def
do_log_stats
(
async
def
do_log_stats
(
self
,
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
...
@@ -90,14 +202,17 @@ class EngineClient(Protocol):
...
@@ -90,14 +202,17 @@ class EngineClient(Protocol):
)
->
None
:
)
->
None
:
...
...
@
abstractmethod
async
def
check_health
(
self
)
->
None
:
async
def
check_health
(
self
)
->
None
:
"""Raise if unhealthy"""
"""Raise if unhealthy"""
...
...
@
abstractmethod
async
def
start_profile
(
self
)
->
None
:
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
"""Start profiling the engine"""
...
...
@
abstractmethod
async
def
stop_profile
(
self
)
->
None
:
async
def
stop_profile
(
self
)
->
None
:
"""Start profiling the engine"""
"""Start profiling the engine"""
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
4d31cd42
...
@@ -9,8 +9,6 @@ from typing import Union
...
@@ -9,8 +9,6 @@ from typing import Union
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_hf_chat_template
,
apply_hf_chat_template
,
...
@@ -237,11 +235,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -237,11 +235,6 @@ class OpenAIServingChat(OpenAIServing):
log_tracing_disabled_warning
()
log_tracing_disabled_warning
()
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
assert
isinstance
(
self
.
engine_client
,
(
AsyncLLMEngine
,
MQLLMEngineClient
)),
\
"Beam search is only supported with"
\
"AsyncLLMEngine and MQLLMEngineClient."
result_generator
=
self
.
engine_client
.
beam_search
(
result_generator
=
self
.
engine_client
.
beam_search
(
engine_inputs
[
'prompt_token_ids'
],
engine_inputs
[
'prompt_token_ids'
],
request_id
,
request_id
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
4d31cd42
...
@@ -8,8 +8,6 @@ from typing import Tuple, Union, cast
...
@@ -8,8 +8,6 @@ from typing import Tuple, Union, cast
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
...
@@ -151,11 +149,6 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -151,11 +149,6 @@ class OpenAIServingCompletion(OpenAIServing):
log_tracing_disabled_warning
()
log_tracing_disabled_warning
()
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
assert
isinstance
(
self
.
engine_client
,
(
AsyncLLMEngine
,
MQLLMEngineClient
)),
\
"Beam search is only supported with"
\
"AsyncLLMEngine and MQLLMEngineClient."
generator
=
self
.
engine_client
.
beam_search
(
generator
=
self
.
engine_client
.
beam_search
(
prompt_inputs
[
"prompt_token_ids"
],
prompt_inputs
[
"prompt_token_ids"
],
request_id_item
,
request_id_item
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment