Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1876 additions
and
1121 deletions
+1876
-1121
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+109
-15
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+39
-8
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+54
-31
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+180
-142
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+75
-0
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+507
-0
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+391
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+5
-4
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-4
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+0
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+22
-6
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+20
-19
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+200
-38
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+179
-139
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+30
-4
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+5
-0
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+0
-451
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+0
-237
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+7
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+49
-20
No files found.
vllm/distributed/parallel_state.py
View file @
539aa992
...
...
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
"""
import
contextlib
import
pickle
import
weakref
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
import
torch
...
...
@@ -34,6 +35,8 @@ from torch.distributed import Backend, ProcessGroup
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
supports_custom_op
@
dataclass
...
...
@@ -69,6 +72,59 @@ def _split_tensor_dict(
return
metadata_list
,
tensor_list
_group_name_counter
:
Dict
[
str
,
int
]
=
{}
def
_get_unique_name
(
name
:
str
)
->
str
:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if
name
not
in
_group_name_counter
:
_group_name_counter
[
name
]
=
0
newname
=
f
"
{
name
}
:
{
_group_name_counter
[
name
]
}
"
_group_name_counter
[
name
]
+=
1
return
newname
_groups
:
Dict
[
str
,
Callable
[[],
"GroupCoordinator"
]]
=
{}
def
_register_group
(
group
:
"GroupCoordinator"
)
->
None
:
# looks like Python 3.8 does not understand `ReferenceType`
_groups
[
group
.
unique_name
]
=
weakref
.
ref
(
group
)
# type: ignore
if
supports_custom_op
():
@
torch
.
library
.
custom_op
(
"vllm::inplace_all_reduce"
,
mutates_args
=
[
"tensor"
])
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce
(
tensor
)
@
inplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
return
@
torch
.
library
.
custom_op
(
"vllm::outplace_all_reduce"
,
mutates_args
=
[])
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce
(
tensor
)
@
outplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
tensor
)
class
GroupCoordinator
:
"""
PyTorch ProcessGroup wrapper for a group of processes.
...
...
@@ -111,7 +167,11 @@ class GroupCoordinator:
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
):
group_name
=
group_name
or
"anonymous"
self
.
unique_name
=
_get_unique_name
(
group_name
)
_register_group
(
self
)
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
local_rank
=
local_rank
...
...
@@ -134,7 +194,7 @@ class GroupCoordinator:
assert
self
.
cpu_group
is
not
None
assert
self
.
device_group
is
not
None
if
torch
.
cuda
.
is_availabl
e
():
if
current_platform
.
is_cuda_alik
e
():
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
else
:
self
.
device
=
torch
.
device
(
"cpu"
)
...
...
@@ -149,28 +209,24 @@ class GroupCoordinator:
from
vllm.distributed.device_communicators.pynccl
import
(
PyNcclCommunicator
)
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
=
None
if
use_pynccl
and
self
.
world_size
>
1
:
self
.
pynccl_comm
=
PyNcclCommunicator
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
else
:
self
.
pynccl_comm
=
None
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
=
None
if
use_custom_allreduce
and
self
.
world_size
>
1
:
# Initialize a custom fast all-reduce implementation.
self
.
ca_comm
=
CustomAllreduce
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
else
:
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
=
None
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
...
...
@@ -264,16 +320,49 @@ class GroupCoordinator:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
if
not
supports_custom_op
():
return
self
.
_all_reduce
(
input_
)
if
self
.
tpu_communicator
is
not
None
and
\
not
self
.
tpu_communicator
.
disabled
:
# TPU handles Dynamo with its own logic.
return
self
.
_all_reduce
(
input_
)
if
self
.
ca_comm
is
not
None
and
self
.
ca_comm
.
should_custom_ar
(
input_
):
return
torch
.
ops
.
vllm
.
outplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
else
:
torch
.
ops
.
vllm
.
inplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
return
input_
def
_all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm
=
self
.
ca_comm
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
...
...
@@ -758,6 +847,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
group_name
=
"world"
,
)
...
...
@@ -767,6 +857,7 @@ def init_model_parallel_group(
backend
:
str
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
)
->
GroupCoordinator
:
if
use_custom_allreduce
is
None
:
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
...
...
@@ -778,6 +869,7 @@ def init_model_parallel_group(
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
group_name
=
group_name
,
)
...
...
@@ -931,7 +1023,8 @@ def initialize_model_parallel(
_TP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
,
use_message_queue_broadcaster
=
True
)
use_message_queue_broadcaster
=
True
,
group_name
=
"tp"
)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups
:
int
=
(
world_size
//
...
...
@@ -947,7 +1040,8 @@ def initialize_model_parallel(
_PP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
,
use_custom_allreduce
=
False
)
use_custom_allreduce
=
False
,
group_name
=
"pp"
)
def
ensure_model_parallel_initialized
(
...
...
vllm/engine/arg_utils.py
View file @
539aa992
...
...
@@ -44,22 +44,36 @@ def nullable_str(val: str):
def
nullable_kvs
(
val
:
str
)
->
Optional
[
Mapping
[
str
,
int
]]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.
Args:
val: String value to be parsed.
Returns:
Dictionary with parsed values.
"""
if
len
(
val
)
==
0
:
return
None
out_dict
:
Dict
[
str
,
int
]
=
{}
for
item
in
val
.
split
(
","
):
try
:
key
,
value
=
item
.
split
(
"="
)
except
TypeError
as
exc
:
msg
=
"Each item should be in the form KEY=VALUE"
raise
ValueError
(
msg
)
from
exc
kv_parts
=
[
part
.
lower
().
strip
()
for
part
in
item
.
split
(
"="
)]
if
len
(
kv_parts
)
!=
2
:
raise
argparse
.
Argument
TypeError
(
"Each item should be in the form KEY=VALUE"
)
key
,
value
=
kv_parts
try
:
out_dict
[
key
]
=
int
(
value
)
parsed_value
=
int
(
value
)
except
ValueError
as
exc
:
msg
=
f
"Failed to parse value of item
{
key
}
=
{
value
}
"
raise
ValueError
(
msg
)
from
exc
raise
argparse
.
ArgumentTypeError
(
msg
)
from
exc
if
key
in
out_dict
and
out_dict
[
key
]
!=
parsed_value
:
raise
argparse
.
ArgumentTypeError
(
f
"Conflicting values specified for key:
{
key
}
"
)
out_dict
[
key
]
=
parsed_value
return
out_dict
...
...
@@ -131,6 +145,7 @@ class EngineArgs:
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
num_scheduler_steps
:
int
=
1
multi_step_stream_outputs
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
...
...
@@ -161,6 +176,7 @@ class EngineArgs:
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
...
...
@@ -458,7 +474,10 @@ class EngineArgs:
default
=
EngineArgs
.
max_seq_len_to_capture
,
help
=
'Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.'
)
'larger than this, we fall back to eager mode. '
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.'
)
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
action
=
'store_true'
,
default
=
EngineArgs
.
disable_custom_all_reduce
,
...
...
@@ -496,6 +515,12 @@ class EngineArgs:
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'
))
parser
.
add_argument
(
'--mm-processor-kwargs'
,
default
=
None
,
type
=
json
.
loads
,
help
=
(
'Overrides for the multimodal input mapping/processing,'
'e.g., image processor. For example: {"num_crops": 4}.'
))
# LoRA related configs
parser
.
add_argument
(
'--enable-lora'
,
...
...
@@ -571,6 +596,10 @@ class EngineArgs:
help
=
(
'Maximum number of forward steps per '
'scheduler call.'
))
parser
.
add_argument
(
'--multi-step-stream-outputs'
,
action
=
'store_true'
,
help
=
'If True, then multi-step will stream outputs for every step'
)
parser
.
add_argument
(
'--scheduler-delay-factor'
,
type
=
float
,
...
...
@@ -805,6 +834,7 @@ class EngineArgs:
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
override_neuron_config
=
self
.
override_neuron_config
,
config_format
=
self
.
config_format
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
)
def
create_load_config
(
self
)
->
LoadConfig
:
...
...
@@ -974,6 +1004,7 @@ class EngineArgs:
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
multi_step_stream_outputs
=
self
.
multi_step_stream_outputs
,
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
and
parallel_config
.
use_ray
),
)
...
...
vllm/engine/async_llm_engine.py
View file @
539aa992
import
asyncio
import
time
import
weakref
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
...
...
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
weak_bind
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
method yields the outputs from the :class:`LLMEngine` to the caller.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
...
...
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
_engine_class
:
Type
[
_AsyncLLMEngine
]
=
_AsyncLLMEngine
def
__init__
(
self
,
worker_use_ray
:
bool
,
*
args
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
**
kwargs
)
->
None
:
self
.
worker_use_ray
=
worker_use_ray
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_engine_class
(
*
args
,
**
kwargs
)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
self
.
use_process_request_outputs_callback
=
True
self
.
use_process_request_outputs_callback
=
(
self
.
engine
.
model_config
.
use_async_output_proc
)
if
self
.
use_process_request_outputs_callback
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
process_request_outputs
weak_bind
(
self
.
process_request_outputs
)
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
# We need to keep a reference to unshielded
...
...
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
# Lazy initialized fields
self
.
_request_tracker
:
RequestTracker
def
__del__
(
self
):
if
rt
:
=
getattr
(
self
,
"request_tracker"
,
None
):
# Wake up engine loop so that it will exit cleanly
rt
.
new_requests_event
.
set
()
@
classmethod
def
_get_executor_cls
(
cls
,
engine_config
:
EngineConfig
)
->
Type
[
ExecutorAsyncBase
]:
...
...
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
raise
TypeError
(
"distributed_executor_backend must be a subclass of "
f
"ExecutorAsyncBase. Got
{
distributed_executor_backend
}
."
)
if
distributed_executor_backend
.
uses_ray
:
# type: ignore
initialize_ray_cluster
(
engine_config
.
parallel_config
)
executor_class
=
distributed_executor_backend
elif
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
if
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutorAsync
executor_class
=
RayTPUExecutorAsync
else
:
...
...
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
from
vllm.executor.xpu_executor
import
XPUExecutorAsync
executor_class
=
XPUExecutorAsync
elif
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutorAsync
executor_class
=
RayXPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.multiproc_xpu_executor
import
(
MultiprocessingXPUExecutorAsync
)
executor_class
=
MultiprocessingXPUExecutorAsync
...
...
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
raise
RuntimeError
(
"Not supported distributed execution model on XPU device."
)
elif
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
...
...
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
def
from_engine_args
(
cls
,
engine_args
:
AsyncEngineArgs
,
engine_config
:
Optional
[
EngineConfig
]
=
None
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
if
engine_config
is
None
:
engine_config
=
engine_args
.
create_engine_config
()
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
if
executor_class
.
uses_ray
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
# Create the async LLM engine.
engine
=
cls
(
executor_class
.
uses_ray
,
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
...
...
@@ -599,9 +601,12 @@ class AsyncLLMEngine:
return
self
.
_errored_with
is
not
None
@
property
def
limit_concurrency
(
self
)
->
Optional
[
int
]:
"""Maximum number of concurrently running requests."""
return
None
def
dead_error
(
self
)
->
BaseException
:
return
AsyncEngineDeadError
(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError)."
)
def
set_errored
(
self
,
exc
:
Exception
)
->
None
:
self
.
_errored_with
=
exc
...
...
@@ -628,7 +633,7 @@ class AsyncLLMEngine:
self
.
_request_tracker
=
RequestTracker
()
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
).
create_task
(
self
.
run_engine_loop
())
).
create_task
(
self
.
run_engine_loop
(
weakref
.
ref
(
self
)
))
self
.
_background_loop_unshielded
.
add_done_callback
(
partial
(
_log_task_completion
,
error_callback
=
self
.
_error_callback
))
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
...
...
@@ -698,9 +703,16 @@ class AsyncLLMEngine:
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
self
.
engine
.
abort_request
(
request_ids
)
async
def
run_engine_loop
(
self
):
@
staticmethod
async
def
run_engine_loop
(
engine_ref
:
ReferenceType
):
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
engine
:
Optional
[
"AsyncLLMEngine"
]
=
engine_ref
()
if
not
engine
:
return
pipeline_parallel_size
=
\
self
.
engine
.
parallel_config
.
pipeline_parallel_size
engine
.
engine
.
parallel_config
.
pipeline_parallel_size
has_requests_in_progress
=
[
False
]
*
pipeline_parallel_size
while
True
:
if
not
any
(
has_requests_in_progress
):
...
...
@@ -711,11 +723,21 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
await
self
.
engine
.
stop_remote_worker_execution_loop_async
()
await
self
.
_request_tracker
.
wait_for_new_requests
()
await
engine
.
engine
.
stop_remote_worker_execution_loop_async
()
request_tracker
=
engine
.
_request_tracker
# Allow engine to be garbage collected while
# waiting for new requests
del
engine
await
asyncio
.
sleep
(
0
)
if
engine_ref
()
is
None
:
return
await
request_tracker
.
wait_for_new_requests
()
engine
=
engine_ref
()
if
not
engine
:
return
logger
.
debug
(
"Got new requests!"
)
requests_in_progress
=
[
asyncio
.
create_task
(
self
.
engine_step
(
ve
))
asyncio
.
create_task
(
engine
.
engine_step
(
ve
))
for
ve
in
range
(
pipeline_parallel_size
)
]
has_requests_in_progress
=
[
True
]
*
pipeline_parallel_size
...
...
@@ -733,19 +755,20 @@ class AsyncLLMEngine:
result
=
task
.
result
()
virtual_engine
=
requests_in_progress
.
index
(
task
)
has_unfinished_requests
=
(
self
.
engine
.
has_unfinished_requests_for_virtual_engine
(
engine
.
engine
.
has_unfinished_requests_for_virtual_engine
(
virtual_engine
))
if
result
or
has_unfinished_requests
:
requests_in_progress
[
virtual_engine
]
=
(
asyncio
.
create_task
(
self
.
engine_step
(
virtual_engine
)))
engine
.
engine_step
(
virtual_engine
)))
has_requests_in_progress
[
virtual_engine
]
=
True
else
:
has_requests_in_progress
[
virtual_engine
]
=
False
except
asyncio
.
TimeoutError
as
exc
:
logger
.
error
(
"Engine iteration timed out. This should never happen!"
)
self
.
set_errored
(
exc
)
engine
.
set_errored
(
exc
)
raise
await
asyncio
.
sleep
(
0
)
...
...
@@ -806,7 +829,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
Yields:
...
...
@@ -1022,7 +1045,7 @@ class AsyncLLMEngine:
async
def
start_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
# noqa: E721
self
.
engine
.
model_executor
.
start_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
...
...
@@ -1030,7 +1053,7 @@ class AsyncLLMEngine:
async
def
stop_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
# noqa: E721
self
.
engine
.
model_executor
.
stop_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop_profile"
)
vllm/engine/llm_engine.py
View file @
539aa992
import
functools
import
time
from
collections
import
deque
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
...
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
,
Device
from
vllm.utils
import
Counter
,
Device
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -95,7 +95,7 @@ class OutputData(NamedTuple):
class
SchedulerContext
:
def
__init__
(
self
):
def
__init__
(
self
,
multi_step_stream_outputs
:
bool
=
False
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
...
...
@@ -103,6 +103,8 @@ class SchedulerContext:
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
self
.
multi_step_stream_outputs
:
bool
=
multi_step_stream_outputs
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
...
...
@@ -144,7 +146,7 @@ class LLMEngine:
decoding.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
...
...
@@ -219,6 +221,7 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
)
->
None
:
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
...
...
@@ -234,8 +237,9 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s)"
,
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -266,8 +270,11 @@ class LLMEngine:
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
multi_step_stream_outputs
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
use_cached_outputs
,
model_config
.
mm_processor_kwargs
,
)
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
...
...
@@ -286,6 +293,7 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
...
@@ -327,136 +335,134 @@ class LLMEngine:
observability_config
=
self
.
observability_config
,
)
init_success
=
False
try
:
if
not
self
.
model_config
.
embedding_mode
:
self
.
_initialize_kv_caches
()
# If usage stat is enabled, collect relevant info.
if
is_usage_stats_enabled
():
from
vllm.model_executor.model_loader
import
(
get_architecture_class_name
)
usage_message
.
report_usage
(
get_architecture_class_name
(
model_config
),
usage_context
,
extra_kvs
=
{
# Common configuration
"dtype"
:
str
(
model_config
.
dtype
),
"tensor_parallel_size"
:
parallel_config
.
tensor_parallel_size
,
"block_size"
:
cache_config
.
block_size
,
"gpu_memory_utilization"
:
cache_config
.
gpu_memory_utilization
,
# Quantization
"quantization"
:
model_config
.
quantization
,
"kv_cache_dtype"
:
str
(
cache_config
.
cache_dtype
),
# Feature flags
"enable_lora"
:
bool
(
lora_config
),
"enable_prompt_adapter"
:
bool
(
prompt_adapter_config
),
"enable_prefix_caching"
:
cache_config
.
enable_prefix_caching
,
"enforce_eager"
:
model_config
.
enforce_eager
,
"disable_custom_all_reduce"
:
parallel_config
.
disable_custom_all_reduce
,
})
if
self
.
tokenizer
:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self
.
tokenizer
.
ping
()
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
if
not
self
.
model_config
.
embedding_mode
:
self
.
_initialize_kv_caches
()
# If usage stat is enabled, collect relevant info.
if
is_usage_stats_enabled
():
from
vllm.model_executor.model_loader
import
(
get_architecture_class_name
)
usage_message
.
report_usage
(
get_architecture_class_name
(
model_config
),
usage_context
,
extra_kvs
=
{
# Common configuration
"dtype"
:
str
(
model_config
.
dtype
),
"tensor_parallel_size"
:
parallel_config
.
tensor_parallel_size
,
"block_size"
:
cache_config
.
block_size
,
"gpu_memory_utilization"
:
cache_config
.
gpu_memory_utilization
,
# Quantization
"quantization"
:
model_config
.
quantization
,
"kv_cache_dtype"
:
str
(
cache_config
.
cache_dtype
),
# Feature flags
"enable_lora"
:
bool
(
lora_config
),
"enable_prompt_adapter"
:
bool
(
prompt_adapter_config
),
"enable_prefix_caching"
:
cache_config
.
enable_prefix_caching
,
"enforce_eager"
:
model_config
.
enforce_eager
,
"disable_custom_all_reduce"
:
parallel_config
.
disable_custom_all_reduce
,
})
self
.
scheduler_contexts
=
[
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
if
self
.
tokenizer
:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self
.
tokenizer
.
ping
()
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
SchedulerContext
(
multi_step_stream_outputs
=
self
.
scheduler_config
.
multi_step_stream_outputs
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
if
model_config
.
use_async_output_proc
:
process_model_outputs
=
weak_bind
(
self
.
_process_model_outputs
)
self
.
async_callbacks
=
[
functools
.
partial
(
self
.
_
process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
partial
(
process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
else
:
self
.
async_callbacks
=
[]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
:
Optional
[
Callable
]
=
None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self
.
scheduler
=
[
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
self
.
async_callbacks
[
v_id
]
if
model_config
.
use_async_output_proc
else
None
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
=
None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self
.
scheduler
=
[
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
self
.
async_callbacks
[
v_id
]
if
model_config
.
use_async_output_proc
else
None
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
# Metric Logging.
if
self
.
log_stats
:
if
stat_loggers
is
not
None
:
self
.
stat_loggers
=
stat_loggers
else
:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
PrometheusStatLogger
)
self
.
stat_loggers
=
{
"logging"
:
LoggingStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
),
"prometheus"
:
PrometheusStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
max_model_len
=
self
.
model_config
.
max_model_len
),
}
self
.
stat_loggers
[
"prometheus"
].
info
(
"cache_config"
,
self
.
cache_config
)
self
.
tracer
=
None
if
self
.
observability_config
.
otlp_traces_endpoint
:
self
.
tracer
=
init_tracer
(
"vllm.llm_engine"
,
self
.
observability_config
.
otlp_traces_endpoint
)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self
.
output_processor
=
(
SequenceGroupOutputProcessor
.
create_output_processor
(
self
.
scheduler_config
,
self
.
detokenizer
,
self
.
scheduler
,
self
.
seq_counter
,
# Metric Logging.
if
self
.
log_stats
:
if
stat_loggers
is
not
None
:
self
.
stat_loggers
=
stat_loggers
else
:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
PrometheusStatLogger
)
self
.
stat_loggers
=
{
"logging"
:
LoggingStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
),
"prometheus"
:
PrometheusStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
max_model_len
=
self
.
model_config
.
max_model_len
),
}
self
.
stat_loggers
[
"prometheus"
].
info
(
"cache_config"
,
self
.
cache_config
)
self
.
tracer
=
None
if
self
.
observability_config
.
otlp_traces_endpoint
:
self
.
tracer
=
init_tracer
(
"vllm.llm_engine"
,
self
.
observability_config
.
otlp_traces_endpoint
)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self
.
output_processor
=
(
SequenceGroupOutputProcessor
.
create_output_processor
(
self
.
scheduler_config
,
self
.
detokenizer
,
self
.
scheduler
,
self
.
seq_counter
,
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
self
.
scheduler_config
.
max_model_len
,
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
self
.
scheduler_config
.
max_model_len
,
get_tokenizer_for_seq
,
),
))
init_success
=
True
finally
:
if
not
init_success
:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self
.
model_executor
.
shutdown
()
),
))
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -625,6 +631,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
None
:
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
...
...
@@ -655,7 +662,8 @@ class LLMEngine:
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
elif
isinstance
(
params
,
PoolingParams
):
seq_group
=
self
.
_create_sequence_group_with_pooling
(
request_id
,
...
...
@@ -664,7 +672,8 @@ class LLMEngine:
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
else
:
raise
ValueError
(
"Either SamplingParams or PoolingParams must be provided."
)
...
...
@@ -689,6 +698,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
"""Add a request to the engine's request pool.
...
...
@@ -707,6 +717,8 @@ class LLMEngine:
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Details:
- Set arrival_time to the current time if it is None.
...
...
@@ -735,6 +747,11 @@ class LLMEngine:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
if
priority
>
0
and
not
self
.
scheduler_config
.
policy
==
"priority"
:
raise
ValueError
(
f
"Got priority
{
priority
}
but "
"Priority scheduling is not enabled."
)
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
...
...
@@ -754,6 +771,7 @@ class LLMEngine:
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
priority
,
)
def
_create_sequence_group_with_sampling
(
...
...
@@ -766,6 +784,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs
=
self
.
get_model_config
().
max_logprobs
...
...
@@ -792,7 +811,8 @@ class LLMEngine:
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
return
seq_group
...
...
@@ -805,6 +825,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
...
...
@@ -817,7 +838,8 @@ class LLMEngine:
lora_request
=
lora_request
,
pooling_params
=
pooling_params
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
return
seq_group
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
...
...
@@ -877,8 +899,8 @@ class LLMEngine:
"""
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
@
staticmethod
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
EmbeddingSequenceGroupOutput
],
)
->
None
:
...
...
@@ -1001,7 +1023,8 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
...
...
@@ -1022,8 +1045,8 @@ class LLMEngine:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
# For multi-step, do
no
t create outputs each iteration
if
not
is_last_step
:
# For multi-step
without streaming
, do
n'
t create outputs each iteration
if
not
is_last_step
and
not
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
...
...
@@ -1040,17 +1063,27 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
# For multi-step with streaming, create outputs each iteration
if
not
is_last_step
and
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
self
.
process_request_outputs_callback
is
not
None
:
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
...
...
@@ -1292,6 +1325,7 @@ class LLMEngine:
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
ctx
.
request_outputs
...
...
@@ -1608,7 +1642,7 @@ class LLMEngine:
def
start_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
# noqa: E721
self
.
model_executor
.
start_profile
()
else
:
self
.
model_executor
.
_run_workers
(
"start_profile"
)
...
...
@@ -1616,7 +1650,7 @@ class LLMEngine:
def
stop_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
# noqa: E721
self
.
model_executor
.
stop_profile
()
else
:
self
.
model_executor
.
_run_workers
(
"stop_profile"
)
...
...
@@ -1700,7 +1734,11 @@ class LLMEngine:
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]):
if
self
.
is_encoder_decoder_model
():
if
self
.
model_config
.
is_multimodal_model
:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids
=
inputs
.
get
(
"prompt_token_ids"
)
elif
self
.
is_encoder_decoder_model
():
prompt_ids
=
inputs
.
get
(
"encoder_prompt_token_ids"
)
else
:
prompt_ids
=
inputs
.
get
(
"prompt_token_ids"
)
...
...
vllm/en
trypoints/openai/rpc
/__init__.py
→
vllm/en
gine/multiprocessing
/__init__.py
View file @
539aa992
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
vllm
import
PoolingParams
from
vllm.inputs
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF
=
2000
IPC_INPUT_EXT
=
"_input_socket"
IPC_OUTPUT_EXT
=
"_output_socket"
IPC_HEALTH_EXT
=
"_health_socket"
IPC_DATA_EXT
=
"_data_socket"
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM
=
0
class
MQEngineDeadError
(
RuntimeError
):
pass
@
dataclass
class
RPC
Generate
Request
:
class
RPC
Process
Request
:
inputs
:
PromptInputs
s
ampling
_p
arams
:
Samp
lingParams
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
@
dataclass
class
RPCError
:
request_id
:
Optional
[
str
]
is_engine_errored
:
bool
exception
:
BaseException
@
dataclass
class
RPCAbortRequest
:
request_id
:
str
class
RPC
Utility
Request
(
Enum
):
class
RPC
Startup
Request
(
Enum
):
IS_SERVER_READY
=
1
GET_MODEL_CONFIG
=
2
GET_DECODING_CONFIG
=
3
GET_PARALLEL_CONFIG
=
4
GET_SCHEDULER_CONFIG
=
5
GET_LORA_CONFIG
=
6
DO_LOG_STATS
=
7
IS_SERVER_HEALTHY
=
8
IS_TRACING_ENABLED
=
9
START_PROFILE
=
10
STOP_PROFILE
=
11
RPC_REQUEST_TYPE
=
Union
[
RPCGenerateRequest
,
RPCAbortRequest
,
RPCUtilityRequest
]
@
dataclass
class
RPCStartupResponse
:
tracing_enabled
:
bool
class
RPCUProfileRequest
(
Enum
):
START_PROFILE
=
1
STOP_PROFILE
=
2
RPC_REQUEST_T
=
Union
[
RPCProcessRequest
,
RPCAbortRequest
,
RPCStartupRequest
,
RPCUProfileRequest
]
REQUEST_OUTPUTS_T
=
Union
[
List
[
RequestOutput
],
RPCError
]
def
ENGINE_DEAD_ERROR
(
error
:
Optional
[
BaseException
]
=
None
)
->
MQEngineDeadError
:
if
error
is
None
:
return
MQEngineDeadError
(
"Engine loop is not running. Inspect the stacktrace to "
"find the original error"
)
return
MQEngineDeadError
(
"Engine loop is not running. Inspect the stacktrace to "
f
"find the original error:
{
repr
(
error
)
}
."
)
vllm/engine/multiprocessing/client.py
0 → 100644
View file @
539aa992
import
asyncio
import
copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Union
)
import
cloudpickle
import
zmq
import
zmq.asyncio
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
PoolingParams
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
logger
=
init_logger
(
__name__
)
class
MQClientClosedError
(
Exception
):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class
MQLLMEngineClient
:
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
MQLLMEngine and MQLLMEngineClient are intended to run in separate
processes communicating via zeromq ipc sockets.
The entrypoint to MQLLMEngineClient is through the generate()
method. On generate() MQLLMEngine does three things:
- Creates an asyncio output queue
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
- Pulls RequestOutputs from its queue and yields them
MQLLMEngine runs two background loops:
- output_loop: the output loop pulls List[RequestOutput]
from the MQLLMEngine via zmq (each list is the output
of one engine_step in the LLMEngine). It then parses
the list and pushes individual request_outputs into
the corresponding output_queue such that they can be
consumed by the .generate() method.
- health_loop: the health loop queries the health socket
every N seconds, confirming the engine is healthy
"""
def
__init__
(
self
,
ipc_path
:
str
,
engine_config
:
EngineConfig
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
# Get the configs.
self
.
model_config
=
engine_config
.
model_config
self
.
decoding_config
=
engine_config
.
decoding_config
# Create the tokenizer group.
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
parallel_config
=
engine_config
.
parallel_config
,
enable_lora
=
bool
(
engine_config
.
lora_config
),
)
# Send RPCGenerateRequest to the MQLLMEngine.
self
.
input_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
input_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_INPUT_EXT
}
"
)
# Receive streams of RequestOutput from the MQLLMEngine.
self
.
output_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
output_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
# IPC path for acking heartbeats.
self
.
heartbeat_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
heartbeat_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
# IPC path for the data socket.
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
# Stream for each individual request.
self
.
output_queues
:
Dict
[
str
,
asyncio
.
Queue
]
=
{}
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self
.
health_loop
:
Optional
[
asyncio
.
Task
]
=
None
@
staticmethod
def
is_unsupported_config
(
engine_args
:
AsyncEngineArgs
):
# Pipeline parallel not yet supported
return
engine_args
.
pipeline_parallel_size
>
1
@
contextmanager
def
get_data_socket
(
self
)
->
Iterator
[
Socket
]:
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
try
:
socket
.
connect
(
self
.
data_ipc_path
)
yield
socket
finally
:
socket
.
close
(
linger
=
0
)
async
def
run_heartbeat_loop
(
self
,
timeout
:
int
):
"""Background loop that continually listens to the RPCServer for
heartbeats.
"""
try
:
while
True
:
if
await
self
.
heartbeat_socket
.
poll
(
timeout
=
timeout
)
==
0
:
# No heartbeat was received. Set error and exit the loop
self
.
_set_errored
(
TimeoutError
(
"No heartbeat received "
"from MQLLMEngine"
))
logger
.
debug
(
"Shutting down MQLLMEngineClient check "
"health loop due to timeout"
)
break
else
:
# Heartbeat received- check the message
await
self
.
_check_success
(
error_message
=
"Heartbeat failed."
,
socket
=
self
.
heartbeat_socket
)
logger
.
debug
(
"Heartbeat successful."
)
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient check health loop."
)
except
Exception
as
e
:
self
.
_set_errored
(
e
)
async
def
run_output_handler_loop
(
self
):
"""Get RequestOutputs from Engine and stream to Request Queues"""
try
:
while
True
:
# Poll, checking for ENGINE_DEAD
while
await
self
.
output_socket
.
poll
(
timeout
=
VLLM_RPC_TIMEOUT
)
==
0
:
logger
.
debug
(
"Waiting for output from MQLLMEngine."
)
# If errored, alert all running requests.
if
self
.
errored
:
for
queue_j
in
tuple
(
self
.
output_queues
.
values
()):
queue_j
.
put_nowait
(
ENGINE_DEAD_ERROR
(
self
.
_errored_with
))
return
message
:
Frame
=
await
self
.
output_socket
.
recv
(
copy
=
False
)
request_outputs
=
pickle
.
loads
(
message
.
buffer
)
is_error
=
isinstance
(
request_outputs
,
(
BaseException
,
RPCError
))
if
is_error
:
if
isinstance
(
request_outputs
,
RPCError
):
rpc_error
:
RPCError
=
request_outputs
request_id
=
rpc_error
.
request_id
exception
=
rpc_error
.
exception
is_engine_errored
=
rpc_error
.
is_engine_errored
else
:
# MPLLMEngine should always return an RPCError to
# the output_socket when an issue arises.
# If we are here, we are in a bad state and
# should shut down the server.
error
:
BaseException
=
request_outputs
logger
.
error
(
"Received Exception %s rather than RPCError from "
"MPLLMEngine. This should never happen."
,
error
)
request_id
=
None
exception
=
error
is_engine_errored
=
True
# Set to error state only on engine critical error
# (and record only the first one)
if
is_engine_errored
and
not
self
.
_errored_with
:
self
.
_errored_with
=
exception
if
request_id
is
None
:
for
queue_i
in
tuple
(
self
.
output_queues
.
values
()):
queue_i
.
put_nowait
(
exception
)
else
:
queue
=
self
.
output_queues
.
get
(
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
exception
)
else
:
# Put each output into the appropriate steam.
for
request_output
in
request_outputs
:
queue
=
self
.
output_queues
.
get
(
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient output handler."
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
with
self
.
get_data_socket
()
as
socket
:
# Wait until server is ready.
response
=
await
self
.
_wait_for_server_rpc
(
socket
)
self
.
tracing_flag
=
response
.
tracing_enabled
# Start health_loop.
self
.
health_loop
=
asyncio
.
create_task
(
self
.
run_heartbeat_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
# Close all sockets and terminate the context.
self
.
context
.
destroy
(
linger
=
0
)
# Cancel background tasks.
if
self
.
health_loop
is
not
None
:
self
.
health_loop
.
cancel
()
self
.
output_loop
.
cancel
()
def
_set_errored
(
self
,
e
:
BaseException
):
logger
.
exception
(
repr
(
e
))
if
self
.
_errored_with
is
None
:
self
.
_errored_with
=
e
@
staticmethod
async
def
_send_get_data_rpc_request
(
request
:
RPCStartupRequest
,
expected_type
:
Any
,
error_message
:
str
,
socket
:
Socket
)
->
Any
:
"""Send an RPC request that is expecting data back."""
# Ping RPCServer with a request.
await
socket
.
send_multipart
((
pickle
.
dumps
(
request
),
),
copy
=
False
)
# Make sure the server responds in time.
if
await
socket
.
poll
(
timeout
=
VLLM_RPC_TIMEOUT
)
==
0
:
raise
TimeoutError
(
"RPCServer didn't reply within "
f
"
{
VLLM_RPC_TIMEOUT
}
ms"
)
# Await the data from the Server.
frame
=
await
socket
.
recv
(
copy
=
False
)
data
=
pickle
.
loads
(
frame
.
buffer
)
if
isinstance
(
data
,
BaseException
):
raise
data
elif
not
isinstance
(
data
,
expected_type
):
raise
ValueError
(
error_message
)
return
data
@
staticmethod
async
def
_send_one_way_rpc_request
(
request
:
RPC_REQUEST_T
,
socket
:
Socket
):
"""Send one-way RPC request to trigger an action."""
if
socket
.
closed
:
raise
MQClientClosedError
()
await
socket
.
send_multipart
((
pickle
.
dumps
(
request
),
))
async
def
_await_ack
(
self
,
error_message
:
str
,
socket
:
Socket
):
"""Await acknowledgement that a request succeeded."""
if
socket
.
closed
:
raise
MQClientClosedError
()
if
await
socket
.
poll
(
timeout
=
VLLM_RPC_TIMEOUT
)
==
0
:
raise
TimeoutError
(
"MQLLMEngine didn't reply within "
f
"
{
VLLM_RPC_TIMEOUT
}
ms"
)
await
self
.
_check_success
(
error_message
,
socket
)
@
staticmethod
async
def
_check_success
(
error_message
:
str
,
socket
:
Socket
):
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
if
socket
.
closed
:
raise
MQClientClosedError
()
frame
=
await
socket
.
recv
(
copy
=
False
)
response
=
pickle
.
loads
(
frame
.
buffer
)
# Raise error if unsuccessful
if
isinstance
(
response
,
BaseException
):
raise
response
elif
(
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
):
raise
ValueError
(
error_message
)
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
return
self
.
decoding_config
async
def
get_model_config
(
self
)
->
ModelConfig
:
return
self
.
model_config
async
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracing_flag
async
def
_wait_for_server_rpc
(
self
,
socket
:
Socket
)
->
RPCStartupResponse
:
"""Wait for the RPCServer to start up."""
return
await
self
.
_send_get_data_rpc_request
(
request
=
RPCStartupRequest
.
IS_SERVER_READY
,
expected_type
=
RPCStartupResponse
,
error_message
=
"Unable to start RPC Server"
,
socket
=
socket
)
async
def
abort
(
self
,
request_id
:
str
):
"""Send an ABORT_REQUEST signal to the RPC Server"""
with
suppress
(
MQClientClosedError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
socket
=
self
.
input_socket
)
async
def
do_log_stats
(
self
):
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
pass
async
def
check_health
(
self
):
"""
The check health loop probes the health status of the
Engine's health every N seconds and sets _errored_with
if the engine is unhealthy.
"""
if
self
.
_errored_with
is
not
None
:
raise
self
.
_errored_with
@
property
def
is_running
(
self
)
->
bool
:
return
not
self
.
errored
@
property
def
is_stopped
(
self
)
->
bool
:
return
self
.
errored
@
property
def
errored
(
self
)
->
bool
:
return
self
.
_errored_with
is
not
None
@
property
def
dead_error
(
self
)
->
BaseException
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
def
encode
(
self
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
async
def
_process_request
(
self
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
if
self
.
_errored_with
is
not
None
:
raise
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
# 1) Create output queue for this requests.
queue
:
asyncio
.
Queue
[
Union
[
RequestOutput
,
BaseException
]]
=
asyncio
.
Queue
()
self
.
output_queues
[
request_id
]
=
queue
try
:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if
isinstance
(
params
,
SamplingParams
)
and
params
.
logits_processors
:
# Defensive shallow copy
params
=
copy
.
copy
(
params
)
logits_processors
=
params
.
logits_processors
params
.
logits_processors
=
None
lp_bytes
=
cloudpickle
.
dumps
(
logits_processors
)
else
:
lp_bytes
=
None
request_bytes
=
pickle
.
dumps
(
RPCProcessRequest
(
inputs
=
inputs
,
params
=
params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts
=
(
request_bytes
,
lp_bytes
)
if
lp_bytes
else
(
request_bytes
,
)
await
self
.
input_socket
.
send_multipart
(
parts
,
copy
=
False
)
# 4) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
# queue after pulling them from the zmq socket.
finished
=
False
try
:
while
not
finished
:
request_output
=
await
queue
.
get
()
if
isinstance
(
request_output
,
BaseException
):
raise
request_output
finished
=
request_output
.
finished
yield
request_output
finally
:
# Request was canceled by the client.
if
not
finished
and
not
self
.
errored
:
await
self
.
abort
(
request_id
)
finally
:
self
.
output_queues
.
pop
(
request_id
)
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
START_PROFILE
,
socket
=
self
.
input_socket
)
async
def
stop_profile
(
self
)
->
None
:
"""Stop profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
vllm/engine/multiprocessing/engine.py
0 → 100644
View file @
539aa992
import
pickle
import
signal
import
threading
import
time
from
contextlib
import
contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
import
cloudpickle
import
zmq
from
vllm
import
AsyncEngineArgs
,
LLMEngine
,
SamplingParams
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
logger
=
init_logger
(
__name__
)
POLLING_TIMEOUT_MS
=
10000
HEALTHY_RESPONSE
=
(
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
)
class
MQLLMEngine
:
"""A multiprocessing wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The :class:`LLMEngine` generate or encode process is kicked off when a new
RPCProcessRequest is received by the input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
:class:`LLMEngine.step()`, and sends the RequestOutputs back over
the output_socket.
If use_async_sockets is set, the logic associated with reading new
requests from the socket and sending data to the socket is passed
as a callback to the llm_engine, which calls the logic asynchronously
such that the IPC can be overlapped with the GPU.
Args:
ipc_path: Base path for zeromq interprocess messaging
use_async_sockets: Whether to make send/recv async with GPU
log_requests: Whether to log the requests.
*args: Arguments for :class:`LLMEngine`.
**kwargs: Arguments for :class:`LLMEngine`.
"""
def
__init__
(
self
,
ipc_path
:
str
,
use_async_sockets
:
bool
,
*
args
,
log_requests
:
bool
=
True
,
**
kwargs
)
->
None
:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
use_cached_outputs
=
True
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
,
use_cached_outputs
=
use_cached_outputs
)
self
.
log_requests
=
log_requests
self
.
use_async_sockets
=
use_async_sockets
if
self
.
use_async_sockets
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
_async_socket_engine_callback
self
.
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
# Receive input from the client.
self
.
input_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PULL
)
self
.
input_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_INPUT_EXT
}
"
)
# Send output stream back to client.
self
.
output_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
output_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
# Send heartbeats back to client.
self
.
heartbeat_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
heartbeat_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
# IPC path for the data socket.
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
# Error state.
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
# Heartbeat thread
self
.
heartbeat_thread
=
threading
.
Thread
(
target
=
self
.
_heartbeat_loop
,
daemon
=
True
)
self
.
_heartbeat_stop_event
=
threading
.
Event
()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
heartbeat_interval_seconds
=
VLLM_RPC_TIMEOUT
/
5000.0
self
.
_last_alive_time
=
time
.
time
()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
last_alive_threshold
=
VLLM_RPC_TIMEOUT
*
3.0
/
1000.0
@
property
def
dead_error
(
self
)
->
BaseException
:
if
self
.
_errored_with
is
not
None
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
else
:
return
ENGINE_DEAD_ERROR
()
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
ipc_path
:
str
):
"""Creates an MQLLMEngine from the engine arguments."""
engine_config
=
engine_args
.
create_engine_config
()
executor_class
=
LLMEngine
.
_get_executor_cls
(
engine_config
)
return
cls
(
ipc_path
=
ipc_path
,
use_async_sockets
=
engine_config
.
model_config
.
use_async_output_proc
,
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
)
def
start
(
self
):
try
:
try
:
logger
.
debug
(
"Starting Startup Loop."
)
self
.
run_startup_loop
()
logger
.
debug
(
"Starting heartbeat thread"
)
self
.
heartbeat_thread
.
start
()
logger
.
debug
(
"Starting Engine Loop."
)
self
.
run_engine_loop
()
except
Exception
as
e
:
logger
.
exception
(
repr
(
e
))
except
KeyboardInterrupt
:
logger
.
debug
(
"Shutting down MQLLMEngine."
)
finally
:
logger
.
debug
(
"MQLLMEngine is shut down."
)
self
.
cleanup
()
def
cleanup
(
self
):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self
.
_heartbeat_stop_event
.
set
()
self
.
ctx
.
destroy
(
linger
=
0
)
del
self
.
engine
@
contextmanager
def
make_data_socket
(
self
)
->
Iterator
[
zmq
.
Socket
]:
# type: ignore[name-defined]
socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
ROUTER
)
try
:
socket
.
bind
(
self
.
data_ipc_path
)
yield
socket
finally
:
socket
.
close
(
linger
=
0
)
def
run_startup_loop
(
self
)
->
None
:
"""Startup loop for sending data from Engine -> Client."""
with
self
.
make_data_socket
()
as
socket
:
response
:
Union
[
RPCStartupResponse
,
BaseException
]
try
:
identity
,
message
=
socket
.
recv_multipart
(
copy
=
False
)
request
:
RPCStartupRequest
=
pickle
.
loads
(
message
.
buffer
)
# Handle the query from the Client.
if
request
==
RPCStartupRequest
.
IS_SERVER_READY
:
tracing_enabled
=
self
.
engine
.
is_tracing_enabled
()
response
=
RPCStartupResponse
(
tracing_enabled
=
tracing_enabled
)
except
Exception
as
e
:
response
=
e
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
response
)),
copy
=
False
)
def
run_engine_loop
(
self
):
"""Core busy loop of the LLMEngine."""
while
True
:
self
.
_alive
()
if
not
self
.
engine
.
has_unfinished_requests
():
# Poll until there is work to do.
while
self
.
input_socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
self
.
_alive
()
self
.
engine
.
do_log_stats
()
logger
.
debug
(
"Waiting for new requests in engine loop."
)
# Handle any input from the client.
self
.
handle_new_input
()
# Engine step.
request_outputs
=
self
.
engine_step
()
# Send request outputs (if async, done in engine_step callback).
if
not
self
.
use_async_sockets
:
self
.
_send_outputs
(
request_outputs
)
def
engine_step
(
self
)
->
List
[
RequestOutput
]:
"""Engine step wrapper with error handling."""
try
:
return
self
.
engine
.
step
()
except
SystemExit
:
raise
except
BaseException
as
e
:
self
.
_set_errored
(
e
)
rpc_err
=
RPCError
(
request_id
=
None
,
is_engine_errored
=
True
,
exception
=
e
)
self
.
_send_outputs
(
rpc_err
)
raise
e
def
handle_new_input
(
self
):
"""Handle new input from the socket"""
try
:
while
self
.
input_socket
.
poll
(
timeout
=
0
)
!=
0
:
frames
=
self
.
input_socket
.
recv_multipart
(
copy
=
False
)
request
=
pickle
.
loads
(
frames
[
0
].
buffer
)
if
isinstance
(
request
,
RPCProcessRequest
):
if
len
(
frames
)
>
1
:
# Use cloudpickle for logits processors
assert
isinstance
(
request
.
params
,
SamplingParams
)
lprocs
=
cloudpickle
.
loads
(
frames
[
1
].
buffer
)
request
.
params
.
logits_processors
=
lprocs
self
.
_handle_process_request
(
request
)
elif
isinstance
(
request
,
RPCAbortRequest
):
self
.
_handle_abort_request
(
request
)
elif
isinstance
(
request
,
RPCUProfileRequest
):
if
request
==
RPCUProfileRequest
.
START_PROFILE
:
self
.
start_profile
()
else
:
self
.
stop_profile
()
else
:
raise
ValueError
(
"Unknown RPCRequest Type: "
f
"
{
type
(
request
)
}
"
)
except
Exception
as
e
:
self
.
_set_errored
(
e
)
self
.
_send_unhealthy
(
e
)
raise
e
def
_handle_process_request
(
self
,
request
:
RPCProcessRequest
):
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
request_id
=
request
.
request_id
if
self
.
_errored_with
is
not
None
:
rpc_err
=
RPCError
(
request_id
=
request_id
,
is_engine_errored
=
True
,
exception
=
ENGINE_DEAD_ERROR
(
self
.
_errored_with
))
self
.
_send_outputs
(
rpc_err
)
try
:
self
.
engine
.
add_request
(
request_id
=
request_id
,
inputs
=
request
.
inputs
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
prompt_adapter_request
=
request
.
prompt_adapter_request
)
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request
.
request_id
)
except
Exception
as
e
:
# We do not set self._errored = True here, since the error
# is due to an issue adding this request to the engine,
# rather than an issue with the engine itself.
is_errored
=
self
.
_errored_with
is
not
None
rpc_err
=
RPCError
(
request_id
=
request_id
,
is_engine_errored
=
is_errored
,
exception
=
e
)
self
.
_send_outputs
(
rpc_err
)
# Remove request from the engine.
self
.
engine
.
abort_request
(
request_id
)
def
_handle_abort_request
(
self
,
request
:
RPCAbortRequest
):
self
.
engine
.
abort_request
(
request
.
request_id
)
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
def
_heartbeat_loop
(
self
):
while
not
self
.
_heartbeat_stop_event
.
wait
(
timeout
=
self
.
heartbeat_interval_seconds
):
# Loops until the stop event is set
self
.
_heartbeat
()
logger
.
debug
(
"Exiting MQLLMEngine heartbeat thread"
)
def
_heartbeat
(
self
):
# Send unhealthy if engine has already errored
if
self
.
_errored_with
is
not
None
:
self
.
_send_unhealthy
(
self
.
_errored_with
)
# Check for life of the main loop
elif
time
.
time
()
-
self
.
_last_alive_time
>
self
.
last_alive_threshold
:
self
.
_send_unhealthy
(
RuntimeError
(
"Engine loop has died"
))
else
:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try
:
self
.
engine
.
check_health
()
self
.
_send_healthy
()
except
Exception
as
e
:
self
.
_set_errored
(
e
)
self
.
_send_unhealthy
(
e
)
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
"""Send List of RequestOutput to RPCClient."""
if
outputs
:
output_bytes
=
pickle
.
dumps
(
outputs
)
self
.
output_socket
.
send_multipart
((
output_bytes
,
),
copy
=
False
)
def
_send_healthy
(
self
):
"""Send HEALTHY message to RPCClient."""
if
not
self
.
heartbeat_socket
.
closed
:
self
.
heartbeat_socket
.
send_multipart
(
HEALTHY_RESPONSE
,
copy
=
False
)
def
_send_unhealthy
(
self
,
error
:
BaseException
):
"""Send UNHEALTHY message to RPCClient."""
if
not
self
.
heartbeat_socket
.
closed
:
error_bytes
=
pickle
.
dumps
(
error
)
self
.
heartbeat_socket
.
send_multipart
((
error_bytes
,
),
copy
=
False
)
def
_async_socket_engine_callback
(
self
,
request_outputs
:
REQUEST_OUTPUTS_T
):
"""Callback used by engine to make socket handling async with GPU."""
self
.
_send_outputs
(
request_outputs
)
self
.
handle_new_input
()
def
_set_errored
(
self
,
e
:
BaseException
):
"""Log and set errored status if this is the first issue."""
if
self
.
_errored_with
is
None
:
self
.
_errored_with
=
e
def
_alive
(
self
):
self
.
_last_alive_time
=
time
.
time
()
def
start_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
start_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
def
stop_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
stop_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop_profile"
)
def
run_mp_engine
(
engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
ipc_path
:
str
):
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm
raise
KeyboardInterrupt
(
"MQLLMEngine terminated"
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
engine
=
MQLLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
usage_context
,
ipc_path
=
ipc_path
)
engine
.
start
()
vllm/engine/output_processor/multi_step.py
View file @
539aa992
...
...
@@ -9,8 +9,8 @@ from vllm.engine.output_processor.single_step import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
...
...
@@ -110,10 +110,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
#
-1 means the output token is not
valid (eg. due to spec decode
#
entries in sample tokens may be in
valid (eg. due to spec decode
# rejecting tokens).
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
assert
valid_samples
...
...
vllm/engine/protocol.py
View file @
539aa992
...
...
@@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
@
runtime_checkable
class
Async
EngineClient
(
Protocol
):
"""Protocol class for Clients to
AsyncLLM
Engine"""
class
EngineClient
(
Protocol
):
"""Protocol class for Clients to Engine"""
@
property
def
is_running
(
self
)
->
bool
:
...
...
@@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol):
...
@
property
def
limit_concurrency
(
self
)
->
Optional
[
int
]
:
"""Maximum number of concurrently running requests."""
def
dead_error
(
self
)
->
BaseException
:
...
def
generate
(
self
,
...
...
vllm/entrypoints/api_server.py
View file @
539aa992
...
...
@@ -121,7 +121,6 @@ async def run_server(args: Namespace,
shutdown_task
=
await
serve_http
(
app
,
engine
=
engine
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
...
...
vllm/entrypoints/chat_utils.py
View file @
539aa992
...
...
@@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
return
"<image>"
if
model_type
==
"mllama"
:
return
"<|image|>"
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|image_pad|><|vision_end|>"
...
...
@@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser
=
partial
(
cast
,
ChatCompletionContentPartImageParam
)
_AudioParser
=
partial
(
cast
,
ChatCompletionContentPartAudioParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
MODEL_KEEP_MULTI_MODAL_CONTENT
=
{
'mllama'
}
def
_parse_chat_message_content_parts
(
...
...
@@ -368,7 +371,11 @@ def _parse_chat_message_content_parts(
texts
:
List
[
str
]
=
[]
mm_parser
=
mm_tracker
.
create_parser
()
keep_multimodal_content
=
\
mm_tracker
.
_model_config
.
hf_config
.
model_type
in
\
MODEL_KEEP_MULTI_MODAL_CONTENT
has_image
=
False
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
...
...
@@ -383,6 +390,7 @@ def _parse_chat_message_content_parts(
"will be ignored."
)
mm_parser
.
parse_image
(
image_url
[
"url"
])
has_image
=
True
elif
part_type
==
"audio_url"
:
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
...
...
@@ -394,12 +402,20 @@ def _parse_chat_message_content_parts(
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
mm_placeholder_counts
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
text_prompt
)
return
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
if
keep_multimodal_content
:
text_prompt
=
"
\n
"
.
join
(
texts
)
role_content
=
[{
'type'
:
'text'
,
'text'
:
text_prompt
}]
if
has_image
:
role_content
=
[{
'type'
:
'image'
}]
+
role_content
return
[
ConversationMessage
(
role
=
role
,
content
=
role_content
)]
# type: ignore
else
:
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
mm_placeholder_counts
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
text_prompt
)
return
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
# No need to validate using Pydantic again
...
...
vllm/entrypoints/launcher.py
View file @
539aa992
...
...
@@ -4,19 +4,18 @@ from http import HTTPStatus
from
typing
import
Any
import
uvicorn
from
fastapi
import
FastAPI
,
Response
from
fastapi
import
FastAPI
,
Request
,
Response
from
vllm
import
envs
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.
protocol
import
Async
Engine
Client
from
vllm.engine.
multiprocessing
import
MQ
Engine
DeadError
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_process_using_port
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
engine
:
AsyncEngineClient
,
**
uvicorn_kwargs
:
Any
):
async
def
serve_http
(
app
:
FastAPI
,
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
methods
=
getattr
(
route
,
"methods"
,
None
)
...
...
@@ -27,18 +26,9 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if
engine
.
limit_concurrency
is
not
None
:
logger
.
info
(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing"
,
engine
.
limit_concurrency
)
uvicorn_kwargs
[
"limit_concurrency"
]
=
engine
.
limit_concurrency
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
server
=
uvicorn
.
Server
(
config
)
_add_shutdown_handlers
(
app
,
server
,
engine
)
_add_shutdown_handlers
(
app
,
server
)
loop
=
asyncio
.
get_running_loop
()
...
...
@@ -64,19 +54,19 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger
.
debug
(
"port %s is used by process %s launched with command:
\n
%s"
,
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
logger
.
info
(
"
Gracefully stopping http
server"
)
logger
.
info
(
"
Shutting down FastAPI HTTP
server
.
"
)
return
server
.
shutdown
()
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
,
engine
:
AsyncEngineClient
)
->
None
:
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
)
->
None
:
"""Adds handlers for fatal errors that should crash the server"""
@
app
.
exception_handler
(
RuntimeError
)
async
def
runtime_error_handler
(
_
,
__
):
async
def
runtime_error_handler
(
request
:
Request
,
__
):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine
=
request
.
app
.
state
.
engine_client
if
(
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine
.
errored
and
not
engine
.
is_running
):
logger
.
fatal
(
"AsyncLLMEngine has failed, terminating server "
...
...
@@ -91,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
@
app
.
exception_handler
(
AsyncEngineDeadError
)
async
def
engine_dead_handler
(
_
,
__
):
async
def
async_
engine_dead_handler
(
_
,
__
):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
...
...
@@ -100,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
@
app
.
exception_handler
(
MQEngineDeadError
)
async
def
mq_engine_dead_handler
(
_
,
__
):
"""Kill the server if the mq engine is already dead. It will
not handle any further requests."""
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
logger
.
fatal
(
"MQLLMEngine is already dead, terminating server "
"process"
)
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
vllm/entrypoints/llm.py
View file @
539aa992
import
itertools
from
contextlib
import
contextmanager
from
typing
import
ClassVar
,
List
,
Optional
,
Sequence
,
Union
,
cast
,
overload
from
dataclasses
import
dataclass
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
,
overload
)
from
tqdm
import
tqdm
...
...
@@ -29,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
@
dataclass
class
BeamSearchSequence
:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens
:
List
[
int
]
cum_logprob
:
float
=
0.0
text
:
Optional
[
str
]
=
None
@
dataclass
class
BeamSearchOutput
:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences
:
List
[
BeamSearchSequence
]
class
BeamSearchInstance
:
def
__init__
(
self
,
prompt_tokens
:
List
[
int
]):
self
.
beams
:
List
[
BeamSearchSequence
]
=
[
BeamSearchSequence
(
tokens
=
prompt_tokens
)
]
self
.
completed
:
List
[
BeamSearchSequence
]
=
[]
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
...
...
@@ -88,7 +122,9 @@ class LLM:
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
...
...
@@ -131,15 +167,14 @@ class LLM:
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
,
)
->
None
:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
it defaults to False.
'''
if
"disable_log_stats"
not
in
kwargs
:
...
...
@@ -173,6 +208,7 @@ class LLM:
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
mm_processor_kwargs
=
mm_processor_kwargs
,
**
kwargs
,
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
...
...
@@ -284,7 +320,8 @@ class LLM:
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
GuidedDecodingRequest
]]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
...
...
@@ -303,6 +340,8 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
Returns:
A list of ``RequestOutput`` objects containing the
...
...
@@ -343,20 +382,122 @@ class LLM:
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
guided_options
=
guided_options_request
)
guided_options
=
guided_options_request
,
priority
=
priority
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
def
beam_search
(
self
,
prompts
:
List
[
Union
[
str
,
List
[
int
]]],
beam_width
:
int
,
max_tokens
:
int
,
ignore_eos
:
bool
=
False
,
)
->
List
[
BeamSearchOutput
]:
"""
Generate sequences using beam search.
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
beam_width: The number of beams to keep at each step.
max_tokens: The max number of tokens to generate for each prompt.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
0.0
)
instances
:
List
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
prompt_tokens
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
for
_
in
range
(
max_tokens
):
all_beams
:
List
[
BeamSearchSequence
]
=
list
(
sum
((
instance
.
beams
for
instance
in
instances
),
[]))
pos
=
[
0
]
+
list
(
itertools
.
accumulate
(
len
(
instance
.
beams
)
for
instance
in
instances
))
instance_start_and_end
:
List
[
Tuple
[
int
,
int
]]
=
list
(
zip
(
pos
[:
-
1
],
pos
[
1
:]))
if
len
(
all_beams
)
==
0
:
break
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
]
# only runs for one step
# we don't need to use tqdm here
output
=
self
.
generate
(
prompts_batch
,
sampling_params
=
beam_search_params
,
use_tqdm
=
False
)
for
(
start
,
end
),
instance
in
zip
(
instance_start_and_end
,
instances
):
instance_new_beams
=
[]
for
i
in
range
(
start
,
end
):
current_beam
=
all_beams
[
i
]
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
# if `result.outputs[0].logprobs` is None, it means
# the sequence is completed because of the max-model-len
# or abortion. we don't need to add it to the new beams.
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
instance
.
completed
.
append
(
new_beam
)
else
:
instance_new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
instance_new_beams
,
key
=
lambda
x
:
x
.
cum_logprob
,
reverse
=
True
)
instance
.
beams
=
sorted_beams
[:
beam_width
]
outputs
=
[]
for
instance
in
instances
:
instance
.
completed
.
extend
(
instance
.
beams
)
sorted_completed
=
sorted
(
instance
.
completed
,
key
=
lambda
x
:
x
.
cum_logprob
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
)
outputs
.
append
(
BeamSearchOutput
(
sequences
=
best_beams
))
return
outputs
def
chat
(
self
,
messages
:
List
[
ChatCompletionMessageParam
],
messages
:
Union
[
List
[
ChatCompletionMessageParam
],
List
[
List
[
ChatCompletionMessageParam
]]],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
add_generation_prompt
:
bool
=
True
,
tools
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
)
->
List
[
RequestOutput
]:
"""
Generate responses for a chat conversation.
...
...
@@ -369,8 +510,9 @@ class LLM:
to the OpenAI API.
Args:
messages: A single conversation represented as a list of messages.
Each message is a dictionary with 'role' and 'content' keys.
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
...
...
@@ -387,40 +529,56 @@ class LLM:
A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages.
"""
list_of_messages
:
List
[
List
[
ChatCompletionMessageParam
]]
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
messages
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
# Handle multi and single conversations
if
is_list_of
(
messages
,
list
):
# messages is List[List[...]]
list_of_messages
=
messages
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
# messages is List[...]
list_of_messages
=
[
messages
]
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
for
msgs
in
list_of_messages
:
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
conversation
,
mm_data
=
parse_chat_messages
(
msgs
,
model_config
,
tokenizer
)
prompt_data
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt_data
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
msgs
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
tools
=
tools
,
)
else
:
prompt_data
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
tools
=
tools
,
)
prompt
:
Union
[
TokensPrompt
,
TextPrompt
]
if
is_list_of
(
prompt_data
,
int
):
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_data
)
else
:
prompt
=
TextPrompt
(
prompt
=
prompt_data
)
inputs
:
PromptInputs
if
is_list_of
(
prompt
,
int
):
inputs
=
TokensPrompt
(
prompt_token_ids
=
prompt
)
else
:
inputs
=
TextPrompt
(
prompt
=
prompt
)
if
mm_data
is
not
None
:
prompt
[
"multi_modal_data"
]
=
mm_data
if
mm_data
is
not
None
:
inputs
[
"multi_modal_data"
]
=
mm_data
prompts
.
append
(
prompt
)
return
self
.
generate
(
inpu
ts
,
promp
ts
,
sampling_params
=
sampling_params
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
...
...
@@ -628,6 +786,7 @@ class LLM:
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
...
...
@@ -657,6 +816,7 @@ class LLM:
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
[
i
]
if
priority
else
0
,
)
def
_add_request
(
...
...
@@ -665,6 +825,7 @@ class LLM:
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
...
...
@@ -673,6 +834,7 @@ class LLM:
params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
def
_add_guided_processor
(
...
...
vllm/entrypoints/openai/api_server.py
View file @
539aa992
...
...
@@ -4,16 +4,21 @@ import inspect
import
multiprocessing
import
os
import
re
import
signal
import
socket
import
tempfile
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
functools
import
partial
from
http
import
HTTPStatus
from
typing
import
AsyncIterator
,
Optional
,
Set
from
typing
import
AsyncIterator
,
Set
import
uvloop
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
starlette.datastructures
import
State
from
starlette.routing
import
Mount
from
typing_extensions
import
assert_never
...
...
@@ -21,7 +26,9 @@ import vllm.envs as envs
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.multiprocessing.engine
import
run_mp_engine
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
...
...
@@ -39,12 +46,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeRequest
,
TokenizeResponse
,
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
from
vllm.logger
import
init_logger
...
...
@@ -54,12 +60,6 @@ from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
async_engine_client
:
AsyncEngineClient
engine_args
:
AsyncEngineArgs
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
openai_serving_tokenization
:
OpenAIServingTokenization
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
...
...
@@ -68,49 +68,42 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
def
model_is_embedding
(
model_name
:
str
,
trust_remote_code
:
bool
,
quantization
:
Optional
[
str
])
->
bool
:
return
ModelConfig
(
model
=
model_name
,
tokenizer
=
model_name
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
trust_remote_code
,
quantization
=
quantization
,
seed
=
0
,
dtype
=
"auto"
).
embedding_mode
@
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
async
def
_force_log
():
while
True
:
await
asyncio
.
sleep
(
10
)
await
async_engine_client
.
do_log_stats
()
if
not
engine_args
.
disable_log_stats
:
task
=
asyncio
.
create_task
(
_force_log
())
_running_tasks
.
add
(
task
)
task
.
add_done_callback
(
_running_tasks
.
remove
)
yield
try
:
if
app
.
state
.
log_stats
:
engine_client
:
EngineClient
=
app
.
state
.
engine_client
async
def
_force_log
():
while
True
:
await
asyncio
.
sleep
(
10.
)
await
engine_client
.
do_log_stats
()
task
=
asyncio
.
create_task
(
_force_log
())
_running_tasks
.
add
(
task
)
task
.
add_done_callback
(
_running_tasks
.
remove
)
else
:
task
=
None
try
:
yield
finally
:
if
task
is
not
None
:
task
.
cancel
()
finally
:
# Ensure app state including engine ref is gc'd
del
app
.
state
@
asynccontextmanager
async
def
build_async_engine_client
(
args
:
Namespace
)
->
AsyncIterator
[
Optional
[
Async
EngineClient
]
]
:
args
:
Namespace
)
->
AsyncIterator
[
EngineClient
]:
# Context manager to handle
async_
engine_client lifecycle
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global
engine_args
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
# Backend itself still global for the silly lil' health handler
global
async_engine_client
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
async_engine_client
=
engine
# type: ignore[assignment]
yield
engine
...
...
@@ -118,26 +111,35 @@ async def build_async_engine_client(
async
def
build_async_engine_client_from_engine_args
(
engine_args
:
AsyncEngineArgs
,
disable_frontend_multiprocessing
:
bool
=
False
,
)
->
AsyncIterator
[
Optional
[
Async
EngineClient
]
]
:
)
->
AsyncIterator
[
EngineClient
]:
"""
Create
Async
EngineClient, either:
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if
(
model_is_embedding
(
engine_args
.
model
,
engine_args
.
trust_remote_code
,
engine_args
.
quantization
)
# Fall back
# TODO: fill out feature matrix.
if
(
MQLLMEngineClient
.
is_unsupported_config
(
engine_args
)
or
disable_frontend_multiprocessing
):
engine_client
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
try
:
yield
engine_client
finally
:
engine_client
.
shutdown_background_loop
()
engine_config
=
engine_args
.
create_engine_config
()
uses_ray
=
getattr
(
AsyncLLMEngine
.
_get_executor_cls
(
engine_config
),
"uses_ray"
,
False
)
build_engine
=
partial
(
AsyncLLMEngine
.
from_engine_args
,
engine_args
=
engine_args
,
engine_config
=
engine_config
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
if
uses_ray
:
# Must run in main thread with ray for its signal handlers to work
engine_client
=
build_engine
()
else
:
engine_client
=
await
asyncio
.
get_running_loop
().
run_in_executor
(
None
,
build_engine
)
yield
engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
...
...
@@ -158,56 +160,58 @@ async def build_async_engine_client_from_engine_args(
"and vLLM will properly handle cleanup."
)
# Select random path for IPC.
rpc_path
=
get_open_zmq_ipc_path
()
logger
.
info
(
"Multiprocessing frontend to use %s for RPC Path."
,
rpc_path
)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client
=
AsyncEngineRPCClient
(
rpc_path
)
ipc_path
=
get_open_zmq_ipc_path
()
logger
.
info
(
"Multiprocessing frontend to use %s for IPC Path."
,
ipc_path
)
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context
=
multiprocessing
.
get_context
(
"spawn"
)
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
rpc_server_process
=
context
.
Process
(
target
=
run_rpc_server
,
args
=
(
engine_args
,
UsageContext
.
OPENAI_API_SERVER
,
rpc_path
))
rpc_server_process
.
start
()
logger
.
info
(
"Started engine process with PID %d"
,
rpc_server_process
.
pid
)
context
=
multiprocessing
.
get_context
(
"spawn"
)
engine_process
=
context
.
Process
(
target
=
run_mp_engine
,
args
=
(
engine_args
,
UsageContext
.
OPENAI_API_SERVER
,
ipc_path
))
engine_process
.
start
()
logger
.
info
(
"Started engine process with PID %d"
,
engine_process
.
pid
)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config
=
engine_args
.
create_engine_config
()
mp_engine_client
=
MQLLMEngineClient
(
ipc_path
,
engine_config
)
try
:
while
True
:
try
:
await
rpc
_client
.
setup
()
await
mp_engine
_client
.
setup
()
break
except
TimeoutError
:
if
not
rpc_server_process
.
is_alive
():
logger
.
error
(
"RPCServer process died before responding "
"to readiness probe"
)
yield
None
return
yield
rpc_client
# type: ignore[misc]
if
not
engine_process
.
is_alive
():
raise
RuntimeError
(
"Engine process failed to start"
)
from
None
yield
mp_engine_client
# type: ignore[misc]
finally
:
# Ensure rpc server process was terminated
rpc_server
_process
.
terminate
()
engine
_process
.
terminate
()
# Close all open connections to the backend
rpc
_client
.
close
()
mp_engine
_client
.
close
()
# Wait for server process to join
rpc_server_process
.
join
()
# Wait for engine process to join
engine_process
.
join
(
4
)
if
engine_process
.
exitcode
is
None
:
# Kill if taking longer than 5 seconds to stop
engine_process
.
kill
()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
prometheus_client
import
multiprocess
multiprocess
.
mark_process_dead
(
rpc_server
_process
.
pid
)
multiprocess
.
mark_process_dead
(
engine
_process
.
pid
)
router
=
APIRouter
()
...
...
@@ -239,16 +243,36 @@ def mount_metrics(app: FastAPI):
app
.
routes
.
append
(
metrics_route
)
def
chat
(
request
:
Request
)
->
OpenAIServingChat
:
return
request
.
app
.
state
.
openai_serving_chat
def
completion
(
request
:
Request
)
->
OpenAIServingCompletion
:
return
request
.
app
.
state
.
openai_serving_completion
def
tokenization
(
request
:
Request
)
->
OpenAIServingTokenization
:
return
request
.
app
.
state
.
openai_serving_tokenization
def
embedding
(
request
:
Request
)
->
OpenAIServingEmbedding
:
return
request
.
app
.
state
.
openai_serving_embedding
def
engine_client
(
request
:
Request
)
->
EngineClient
:
return
request
.
app
.
state
.
engine_client
@
router
.
get
(
"/health"
)
async
def
health
()
->
Response
:
async
def
health
(
raw_request
:
Request
)
->
Response
:
"""Health check."""
await
async_
engine_client
.
check_health
()
await
engine_client
(
raw_request
)
.
check_health
()
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/tokenize"
)
async
def
tokenize
(
request
:
TokenizeRequest
):
generator
=
await
openai_serving_tokenization
.
create_tokenize
(
request
)
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
generator
=
await
tokenization
(
raw_request
)
.
create_tokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
...
...
@@ -259,8 +283,8 @@ async def tokenize(request: TokenizeRequest):
@
router
.
post
(
"/detokenize"
)
async
def
detokenize
(
request
:
DetokenizeRequest
):
generator
=
await
openai_serving_tokenization
.
create_detokenize
(
request
)
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
generator
=
await
tokenization
(
raw_request
)
.
create_detokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
...
...
@@ -271,8 +295,8 @@ async def detokenize(request: DetokenizeRequest):
@
router
.
get
(
"/v1/models"
)
async
def
show_available_models
():
models
=
await
openai_serving_completion
.
show_available_models
()
async
def
show_available_models
(
raw_request
:
Request
):
models
=
await
completion
(
raw_request
)
.
show_available_models
()
return
JSONResponse
(
content
=
models
.
model_dump
())
...
...
@@ -286,7 +310,7 @@ async def show_version():
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_chat
.
create_chat_completion
(
generator
=
await
chat
(
raw_request
)
.
create_chat_completion
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
...
...
@@ -301,7 +325,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@
router
.
post
(
"/v1/completions"
)
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_completion
.
create_completion
(
generator
=
await
completion
(
raw_request
)
.
create_completion
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
...
...
@@ -314,7 +338,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@
router
.
post
(
"/v1/embeddings"
)
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_embedding
.
create_embedding
(
generator
=
await
embedding
(
raw_request
)
.
create_embedding
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
...
...
@@ -331,16 +355,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
"used for local development!"
)
@
router
.
post
(
"/start_profile"
)
async
def
start_profile
():
async
def
start_profile
(
raw_request
:
Request
):
logger
.
info
(
"Starting profiler..."
)
await
async_
engine_client
.
start_profile
()
await
engine_client
(
raw_request
)
.
start_profile
()
logger
.
info
(
"Profiler started."
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/stop_profile"
)
async
def
stop_profile
():
async
def
stop_profile
(
raw_request
:
Request
):
logger
.
info
(
"Stopping profiler..."
)
await
async_
engine_client
.
stop_profile
()
await
engine_client
(
raw_request
)
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
return
Response
(
status_code
=
200
)
...
...
@@ -351,13 +375,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"This should ONLY be used for local development!"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
):
response
=
await
openai_serving_chat
.
load_lora_adapter
(
request
)
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
,
raw_request
:
Request
):
response
=
await
chat
(
raw_request
).
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
load_lora_adapter
(
request
)
response
=
await
completion
(
raw_request
)
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
...
...
@@ -365,13 +390,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return
Response
(
status_code
=
200
,
content
=
response
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
):
response
=
await
openai_serving_chat
.
unload_lora_adapter
(
request
)
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
,
raw_request
:
Request
):
response
=
await
chat
(
raw_request
).
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
unload_lora_adapter
(
request
)
response
=
await
completion
(
raw_request
)
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
...
...
@@ -380,7 +406,13 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
FastAPI
(
lifespan
=
lifespan
)
if
args
.
disable_fastapi_docs
:
app
=
FastAPI
(
openapi_url
=
None
,
docs_url
=
None
,
redoc_url
=
None
,
lifespan
=
lifespan
)
else
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
app
.
root_path
=
args
.
root_path
...
...
@@ -396,7 +428,8 @@ def build_app(args: Namespace) -> FastAPI:
@
app
.
exception_handler
(
RequestValidationError
)
async
def
validation_exception_handler
(
_
,
exc
):
err
=
openai_serving_chat
.
create_error_response
(
message
=
str
(
exc
))
chat
=
app
.
state
.
openai_serving_chat
err
=
chat
.
create_error_response
(
message
=
str
(
exc
))
return
JSONResponse
(
err
.
model_dump
(),
status_code
=
HTTPStatus
.
BAD_REQUEST
)
...
...
@@ -428,33 +461,34 @@ def build_app(args: Namespace) -> FastAPI:
return
app
async
def
init_app
(
async_engine_client
:
AsyncEngineClient
,
def
init_app_state
(
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
state
:
State
,
args
:
Namespace
,
)
->
FastAPI
:
app
=
build_app
(
args
)
)
->
None
:
if
args
.
served_model_name
is
not
None
:
served_model_names
=
args
.
served_model_name
else
:
served_model_names
=
[
args
.
model
]
model_config
=
await
async_engine_client
.
get_model_config
()
if
args
.
disable_log_requests
:
request_logger
=
None
else
:
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
global
openai_serving_chat
global
openai_serving_completion
global
openai_serving_embedding
global
openai_serving_tokenization
base_model_paths
=
[
BaseModelPath
(
name
=
name
,
model_path
=
args
.
model
)
for
name
in
served_model_names
]
openai_serving_chat
=
OpenAIServingChat
(
async_engine_client
,
state
.
engine_client
=
engine_client
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
openai_serving_chat
=
OpenAIServingChat
(
engine_client
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
args
.
response_role
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
...
...
@@ -463,48 +497,54 @@ async def init_app(
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
tool_parser
=
args
.
tool_call_parser
)
openai_serving_completion
=
OpenAIServingCompletion
(
async_
engine_client
,
state
.
openai_serving_completion
=
OpenAIServingCompletion
(
engine_client
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
async_
engine_client
,
state
.
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine_client
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
request_logger
=
request_logger
,
)
openai_serving_tokenization
=
OpenAIServingTokenization
(
async_
engine_client
,
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine_client
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
lora_modules
=
args
.
lora_modules
,
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
)
app
.
root_path
=
args
.
root_path
return
app
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
# If None, creation of the client failed and we exit.
if
async_engine_client
is
None
:
return
temp_socket
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
temp_socket
.
bind
((
""
,
args
.
port
))
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm while initializing
raise
KeyboardInterrupt
(
"terminated"
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
async
with
build_async_engine_client
(
args
)
as
engine_client
:
app
=
build_app
(
args
)
model_config
=
await
engine_client
.
get_model_config
()
init_app_state
(
engine_client
,
model_config
,
app
.
state
,
args
)
app
=
await
init_app
(
async_engine_client
,
args
)
temp_socket
.
close
(
)
shutdown_task
=
await
serve_http
(
app
,
engine
=
async_engine_client
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
...
...
@@ -528,4 +568,4 @@ if __name__ == "__main__":
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
asyncio
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
vllm/entrypoints/openai/cli_args.py
View file @
539aa992
...
...
@@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action):
lora_list
:
List
[
LoRAModulePath
]
=
[]
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
lora_list
.
append
(
LoRAModulePath
(
name
,
path
))
if
item
in
[
None
,
''
]:
# Skip if item is None or empty string
continue
if
'='
in
item
and
','
not
in
item
:
# Old format: name=path
name
,
path
=
item
.
split
(
'='
)
lora_list
.
append
(
LoRAModulePath
(
name
,
path
))
else
:
# Assume JSON format
try
:
lora_dict
=
json
.
loads
(
item
)
lora
=
LoRAModulePath
(
**
lora_dict
)
lora_list
.
append
(
lora
)
except
json
.
JSONDecodeError
:
parser
.
error
(
f
"Invalid JSON format for --lora-modules:
{
item
}
"
)
except
TypeError
as
e
:
parser
.
error
(
f
"Invalid fields for --lora-modules:
{
item
}
-
{
str
(
e
)
}
"
)
setattr
(
namespace
,
self
.
dest
,
lora_list
)
...
...
@@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default
=
None
,
nargs
=
'+'
,
action
=
LoRAParserAction
,
help
=
"LoRA module configurations in the format name=path. "
"Multiple modules can be specified."
)
help
=
"LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{
\"
name
\"
:
\"
name
\"
,
\"
local_path
\"
:
\"
path
\"
, "
"
\"
base_model_name
\"
:
\"
id
\"
}'"
)
parser
.
add_argument
(
"--prompt-adapters"
,
type
=
nullable_str
,
...
...
@@ -190,6 +209,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'ID numbers being printed in log.'
'
\n\n
Default: Unlimited'
)
parser
.
add_argument
(
"--disable-fastapi-docs"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
return
parser
...
...
vllm/entrypoints/openai/protocol.py
View file @
539aa992
...
...
@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens
:
Optional
[
int
]
=
0
class
RequestResponseMetadata
(
BaseModel
):
request_id
:
str
final_usage_info
:
Optional
[
UsageInfo
]
=
None
class
JsonSchemaResponseFormat
(
OpenAIBaseModel
):
name
:
str
description
:
Optional
[
str
]
=
None
...
...
vllm/entrypoints/openai/rpc/client.py
deleted
100644 → 0
View file @
93872128
import
asyncio
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
Any
,
AsyncGenerator
,
Iterator
,
Mapping
,
Optional
from
uuid
import
uuid4
import
cloudpickle
import
zmq
import
zmq.asyncio
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
# yapf: disable
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
VLLM_RPC_SOCKET_LIMIT_CUTOFF
,
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_GET_DATA_TIMEOUT_MS
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
logger
=
init_logger
(
__name__
)
# Path used for inprocess proxy.
INPROC_PROXY_PATH
=
f
"inproc://
{
uuid4
()
}
"
class
RPCClientClosedError
(
Exception
):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class
AsyncEngineRPCClient
:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, )
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RCPGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def
__init__
(
self
,
rpc_path
:
str
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
_data_timeout
=
VLLM_RPC_GET_DATA_TIMEOUT_MS
self
.
_errored
=
False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit
=
self
.
context
.
get
(
zmq
.
constants
.
SOCKET_LIMIT
)
assert
isinstance
(
socket_limit
,
int
)
if
socket_limit
<
VLLM_RPC_SOCKET_LIMIT_CUTOFF
:
raise
ValueError
(
f
"Found zmq.constants.SOCKET_LIMIT=
{
socket_limit
}
, which caps "
"the number of concurrent requests vLLM can process. Launch "
"vLLM with --disable-frontend-multiprocessing and open a "
"GitHub issue so we can investigate."
)
# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self
.
context
.
set
(
zmq
.
constants
.
MAX_SOCKETS
,
socket_limit
)
# IPC connection to RPC Server (uses unix sockets).
self
.
to_rpc_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
to_rpc_server
.
bind
(
rpc_path
)
# In process proxy to RPC Server (uses memory-based messaging).
self
.
from_api_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
from_api_server
.
bind
(
INPROC_PROXY_PATH
)
# Asyncio background task for the proxy.
self
.
proxy_in_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
from_api_server
,
self
.
to_rpc_server
))
self
.
proxy_out_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
to_rpc_server
,
self
.
from_api_server
))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
# mulitprocessing. This value is used uvicorn to launch
# with --limit-concurrency to return 503 when server is overloaded.
# We need 2 sockets per request - 2:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self
.
limit_concurrency
=
socket_limit
//
2
-
2
async
def
run_proxy
(
self
,
socket_from
:
Socket
,
socket_to
:
Socket
):
"""Background task that runs a proxy"""
while
True
:
frames
=
await
socket_from
.
recv_multipart
(
copy
=
False
)
await
socket_to
.
send_multipart
(
frames
,
copy
=
False
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
await
self
.
_wait_for_server_rpc
()
# Get the configs.
self
.
model_config
=
await
self
.
_get_model_config_rpc
()
self
.
decoding_config
=
await
self
.
_get_decoding_config_rpc
()
self
.
tracing_flag
=
await
self
.
_is_tracing_enabled_rpc
()
# Create the tokenizer group.
# TODO: refactor OAI server to avoid needing this info.
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
scheduler_config
=
(
await
self
.
_get_scheduler_config_rpc
()),
parallel_config
=
(
await
self
.
_get_parallel_config_rpc
()),
enable_lora
=
bool
(
await
self
.
_get_lora_config_rpc
()),
)
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self
.
from_api_server
.
close
()
self
.
to_rpc_server
.
close
()
self
.
context
.
destroy
()
@
contextmanager
def
to_proxy_socket
(
self
)
->
Iterator
[
Socket
]:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if
self
.
context
.
closed
:
raise
RPCClientClosedError
(
"The ZMQ client has already shut down"
)
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
try
:
socket
.
connect
(
INPROC_PROXY_PATH
)
yield
socket
finally
:
socket
.
close
(
linger
=
0
)
async
def
_send_get_data_rpc_request
(
self
,
request
:
RPCUtilityRequest
,
expected_type
:
Any
,
error_message
:
str
)
->
Any
:
"""Send an RPC request that is expecting data back."""
with
self
.
to_proxy_socket
()
as
socket
:
# Ping RPCServer with a request.
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
request
),
),
copy
=
False
)
# Make sure the server responds
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
# Await the data from the Server.
frame
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
frame
,
Frame
)
data
=
pickle
.
loads
(
frame
.
buffer
)
if
isinstance
(
data
,
Exception
):
# Re-raise exceptions returned by the server
raise
data
if
not
isinstance
(
data
,
expected_type
):
# LoRAConfig can be None.
if
expected_type
==
LoRAConfig
and
data
is
None
:
pass
elif
isinstance
(
data
,
Exception
):
logger
.
error
(
error_message
)
raise
data
else
:
raise
ValueError
(
error_message
)
return
data
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
socket
:
Optional
[
Socket
]
=
None
):
"""Send one-way RPC request to trigger an action."""
async
def
do_rpc_call
(
socket
:
Socket
,
request
:
RPC_REQUEST_TYPE
):
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
request
),
))
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
frame
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
frame
,
Frame
)
return
pickle
.
loads
(
frame
.
buffer
)
# Make a new socket connection.
if
socket
is
None
:
with
self
.
to_proxy_socket
()
as
socket
:
response
=
await
do_rpc_call
(
socket
,
request
)
# Use existing socket connection.
else
:
response
=
await
do_rpc_call
(
socket
,
request
)
if
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
:
if
isinstance
(
response
,
Exception
):
logger
.
error
(
error_message
)
raise
response
raise
ValueError
(
error_message
)
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
return
self
.
decoding_config
async
def
get_model_config
(
self
)
->
ModelConfig
:
return
self
.
model_config
async
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracing_flag
async
def
_wait_for_server_rpc
(
self
):
"""Wait for the RPCServer to start up."""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_READY
,
error_message
=
"Unable to start RPC Server"
)
async
def
_get_model_config_rpc
(
self
)
->
ModelConfig
:
"""Get the ModelConfig object from the RPC Server"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_MODEL_CONFIG
,
expected_type
=
ModelConfig
,
error_message
=
"Could not get ModelConfig from RPC Server"
)
async
def
_get_decoding_config_rpc
(
self
)
->
DecodingConfig
:
"""Get DecodingConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_DECODING_CONFIG
,
expected_type
=
DecodingConfig
,
error_message
=
"Could not get DecodingConfig from RPC Server"
)
async
def
_get_parallel_config_rpc
(
self
)
->
ParallelConfig
:
"""Get ParallelConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
,
expected_type
=
ParallelConfig
,
error_message
=
"Could not get ParallelConfig from RPC Server"
)
async
def
_get_scheduler_config_rpc
(
self
)
->
SchedulerConfig
:
"""Get SchedulerConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
,
expected_type
=
SchedulerConfig
,
error_message
=
"Could not get SchedulerConfig from RPC Server"
)
async
def
_get_lora_config_rpc
(
self
)
->
LoRAConfig
:
"""Get LoRAConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_LORA_CONFIG
,
expected_type
=
LoRAConfig
,
error_message
=
"Could not get LoRAConfig from RPC Server"
)
async
def
_is_tracing_enabled_rpc
(
self
)
->
bool
:
"""Get is_tracing_enabled flag from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
IS_TRACING_ENABLED
,
expected_type
=
bool
,
error_message
=
"Could not get is_tracing_enabled from RPC Server"
)
async
def
abort
(
self
,
request_id
:
str
):
"""Send an ABORT_REQUEST signal to the RPC Server"""
# Suppress timeouts as well.
# In cases where the server is busy processing requests and a very
# large volume of abort requests arrive, it is likely that the server
# will not be able to ack all of them in time. We have seen this when
# we abort 20k requests at once while another 2k are processing- many
# of them time out, but we see the server successfully abort all of the
# requests.
# In this case we assume that the server has received or will receive
# these abort requests, and ignore the timeout. This prevents a massive
# wall of `TimeoutError` stack traces.
with
suppress
(
RPCClientClosedError
,
TimeoutError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
error_message
=
f
"RPCAbortRequest
{
request_id
}
failed"
)
async
def
do_log_stats
(
self
):
"""Send a DO_LOG_STATS signal to the RPC Server"""
with
suppress
(
RPCClientClosedError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
DO_LOG_STATS
,
error_message
=
"RPCRequest DO_LOG_STATS failed."
)
@
property
def
is_running
(
self
)
->
bool
:
return
not
self
.
_errored
@
property
def
is_stopped
(
self
)
->
bool
:
return
self
.
_errored
@
property
def
errored
(
self
)
->
bool
:
return
self
.
_errored
async
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
finished
=
False
try
:
with
self
.
to_proxy_socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
RPCGenerateRequest
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)),
))
# Stream back the results from the RPC Server.
while
not
finished
:
message
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
message
,
Frame
)
request_output
=
pickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request_output
,
Exception
):
# On exception, check if the server is still healthy
# possibly setting the `errored` property.
if
not
self
.
_errored
:
try
:
await
self
.
check_health
(
socket
=
socket
)
except
Exception
as
e
:
self
.
_errored
=
True
logger
.
exception
(
repr
(
e
))
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise
request_output
finished
=
request_output
.
finished
yield
request_output
finally
:
# Request was canceled by the client.
if
not
finished
and
not
self
.
_errored
:
await
self
.
abort
(
request_id
)
async
def
check_health
(
self
,
socket
:
Optional
[
Socket
]
=
None
)
->
None
:
"""Raise if unhealthy"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_HEALTHY
,
error_message
=
"Got Unhealthy response from RPC Server"
,
socket
=
socket
)
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
raise
NotImplementedError
(
"Embeddings not supported with multiprocessing backend"
)
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
START_PROFILE
,
error_message
=
"RPCRequest START_PROFILE failed."
)
async
def
stop_profile
(
self
)
->
None
:
"""Stop profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
STOP_PROFILE
,
error_message
=
"RPCRequest STOP_PROFILE failed."
)
vllm/entrypoints/openai/rpc/server.py
deleted
100644 → 0
View file @
93872128
import
asyncio
import
pickle
import
signal
from
typing
import
Any
,
Coroutine
,
Union
import
cloudpickle
import
uvloop
import
zmq
import
zmq.asyncio
from
typing_extensions
import
Never
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
class
AsyncEngineRPCServer
:
def
__init__
(
self
,
async_engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
rpc_path
:
str
):
# Initialize engine first.
self
.
engine
=
AsyncLLMEngine
.
from_engine_args
(
async_engine_args
,
usage_context
=
usage_context
)
# Initialize context.
self
.
context
=
zmq
.
asyncio
.
Context
()
# Init socket.
self
.
socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
socket
.
connect
(
rpc_path
)
def
cleanup
(
self
):
"""Cleanup all resources."""
self
.
socket
.
close
()
self
.
context
.
destroy
()
self
.
engine
.
shutdown_background_loop
()
# Clear the engine reference so that it can be GC'ed.
del
self
.
engine
async
def
get_config
(
self
,
identity
,
request
):
try
:
config
:
CONFIG_TYPE
if
request
==
RPCUtilityRequest
.
GET_MODEL_CONFIG
:
config
=
await
self
.
engine
.
get_model_config
()
elif
request
==
RPCUtilityRequest
.
GET_DECODING_CONFIG
:
config
=
await
self
.
engine
.
get_decoding_config
()
elif
request
==
RPCUtilityRequest
.
GET_LORA_CONFIG
:
config
=
await
self
.
engine
.
get_lora_config
()
elif
request
==
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
:
config
=
await
self
.
engine
.
get_scheduler_config
()
elif
request
==
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
:
config
=
await
self
.
engine
.
get_parallel_config
()
else
:
raise
ValueError
(
"Unknown Config Request: %s"
,
request
)
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
config
)),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
is_tracing_enabled
(
self
,
identity
):
"""Send the is_tracing_enabled flag"""
tracing_flag
=
await
self
.
engine
.
is_tracing_enabled
()
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
tracing_flag
)))
async
def
do_log_stats
(
self
,
identity
):
"""Log stats and confirm success."""
await
self
.
engine
.
do_log_stats
()
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)))
async
def
is_server_ready
(
self
,
identity
):
"""Notify the client that we are ready."""
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)))
async
def
abort
(
self
,
identity
,
request
:
RPCAbortRequest
):
"""Abort request and notify the client of success."""
try
:
# Abort the request in the llm engine.
await
self
.
engine
.
abort
(
request
.
request_id
)
result
:
Union
[
str
,
Exception
]
=
VLLM_RPC_SUCCESS_STR
except
Exception
as
e
:
result
=
e
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
result
)))
async
def
generate
(
self
,
identity
,
generate_request
:
RPCGenerateRequest
):
try
:
results_generator
=
self
.
engine
.
generate
(
generate_request
.
inputs
,
sampling_params
=
generate_request
.
sampling_params
,
request_id
=
generate_request
.
request_id
,
lora_request
=
generate_request
.
lora_request
,
trace_headers
=
generate_request
.
trace_headers
,
prompt_adapter_request
=
generate_request
.
prompt_adapter_request
)
async
for
request_output
in
results_generator
:
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
request_output
)),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
check_health
(
self
,
identity
):
try
:
await
self
.
engine
.
check_health
()
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)))
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
start_profile
(
self
,
identity
):
logger
.
info
(
"Starting profiler..."
)
await
self
.
engine
.
start_profile
()
logger
.
info
(
"Profiler started."
)
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
))
async
def
stop_profile
(
self
,
identity
):
logger
.
info
(
"Stopping profiler..."
)
await
self
.
engine
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
))
def
_make_handler_coro
(
self
,
identity
,
message
:
Frame
)
->
Coroutine
[
Any
,
Any
,
Never
]:
"""Route the zmq message to the handler coroutine."""
request
=
cloudpickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request
,
RPCGenerateRequest
):
return
self
.
generate
(
identity
,
request
)
elif
isinstance
(
request
,
RPCAbortRequest
):
return
self
.
abort
(
identity
,
request
)
elif
isinstance
(
request
,
RPCUtilityRequest
):
if
request
in
[
RPCUtilityRequest
.
GET_MODEL_CONFIG
,
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
,
RPCUtilityRequest
.
GET_DECODING_CONFIG
,
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
,
RPCUtilityRequest
.
GET_LORA_CONFIG
]:
return
self
.
get_config
(
identity
,
request
)
elif
request
==
RPCUtilityRequest
.
DO_LOG_STATS
:
return
self
.
do_log_stats
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_SERVER_READY
:
return
self
.
is_server_ready
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_SERVER_HEALTHY
:
return
self
.
check_health
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_TRACING_ENABLED
:
return
self
.
is_tracing_enabled
(
identity
)
elif
request
==
RPCUtilityRequest
.
START_PROFILE
:
return
self
.
start_profile
(
identity
)
elif
request
==
RPCUtilityRequest
.
STOP_PROFILE
:
return
self
.
stop_profile
(
identity
)
else
:
raise
ValueError
(
f
"Unknown RPCUtilityRequest type:
{
request
}
"
)
else
:
raise
ValueError
(
f
"Unknown RPCRequest type:
{
request
}
"
)
async
def
run_server_loop
(
self
):
"""Inner RPC Server Loop"""
running_tasks
=
set
()
while
True
:
# Wait for a request.
identity
,
message
=
await
self
.
socket
.
recv_multipart
(
copy
=
False
)
# Process the request async.
task
=
asyncio
.
create_task
(
self
.
_make_handler_coro
(
identity
,
message
))
# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. Below is a common "fire-and-forget" tasks
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
running_tasks
.
add
(
task
)
task
.
add_done_callback
(
running_tasks
.
discard
)
async
def
run_server
(
server
:
AsyncEngineRPCServer
):
# Put the server task into the asyncio loop.
loop
=
asyncio
.
get_running_loop
()
server_task
=
loop
.
create_task
(
server
.
run_server_loop
())
# Interruption handling.
def
signal_handler
()
->
None
:
# Kill the server on interrupt / terminate
server_task
.
cancel
()
loop
.
add_signal_handler
(
signal
.
SIGINT
,
signal_handler
)
loop
.
add_signal_handler
(
signal
.
SIGTERM
,
signal_handler
)
try
:
await
server_task
except
asyncio
.
CancelledError
:
logger
.
info
(
"vLLM ZMQ RPC Server was interrupted."
)
finally
:
# Clean up all resources.
server
.
cleanup
()
def
run_rpc_server
(
async_engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
rpc_path
:
str
):
server
=
AsyncEngineRPCServer
(
async_engine_args
,
usage_context
,
rpc_path
)
uvloop
.
run
(
run_server
(
server
))
vllm/entrypoints/openai/run_batch.py
View file @
539aa992
...
...
@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
...
...
@@ -196,6 +197,10 @@ async def main(args):
engine_args
,
usage_context
=
UsageContext
.
OPENAI_BATCH_RUNNER
)
model_config
=
await
engine
.
get_model_config
()
base_model_paths
=
[
BaseModelPath
(
name
=
name
,
model_path
=
args
.
model
)
for
name
in
served_model_names
]
if
args
.
disable_log_requests
:
request_logger
=
None
...
...
@@ -206,7 +211,7 @@ async def main(args):
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
args
.
response_role
,
lora_modules
=
None
,
prompt_adapters
=
None
,
...
...
@@ -216,7 +221,7 @@ async def main(args):
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
request_logger
=
request_logger
,
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
539aa992
...
...
@@ -9,7 +9,7 @@ from typing import Union
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
...
...
@@ -22,8 +22,10 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
RequestResponseMetadata
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
OpenAIServing
,
PromptAdapterPath
,
TextTokensPrompt
)
...
...
@@ -45,9 +47,9 @@ logger = init_logger(__name__)
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
response_role
:
str
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
...
...
@@ -57,9 +59,9 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids
:
bool
=
False
,
enable_auto_tools
:
bool
=
False
,
tool_parser
:
Optional
[
str
]
=
None
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
,
...
...
@@ -105,6 +107,12 @@ class OpenAIServingChat(OpenAIServing):
logger
.
error
(
"Error with model %s"
,
error_check_ret
)
return
error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if
self
.
engine_client
.
errored
:
raise
self
.
engine_client
.
dead_error
try
:
(
lora_request
,
...
...
@@ -112,8 +120,7 @@ class OpenAIServingChat(OpenAIServing):
)
=
self
.
_maybe_get_adapters
(
request
)
model_config
=
self
.
model_config
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
conversation
,
mm_data_future
=
parse_chat_messages_futures
(
request
.
messages
,
model_config
,
tokenizer
)
...
...
@@ -123,7 +130,8 @@ class OpenAIServingChat(OpenAIServing):
]
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
is_mistral_tokenizer
=
isinstance
(
tokenizer
,
MistralTokenizer
)
if
is_mistral_tokenizer
:
prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
request
.
messages
,
...
...
@@ -159,15 +167,20 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
create_error_response
(
"tool_choice =
\"
required
\"
is not supported!"
)
# "auto" tools requires --enable-auto-tool-choice
# and --tool-call-parser
if
request
.
tool_choice
==
"auto"
and
not
(
if
not
is_mistral_tokenizer
and
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
self
.
tool_parser
is
not
None
):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return
self
.
create_error_response
(
"
\"
auto
\"
tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
request_id
=
f
"chat-
{
random_uuid
()
}
"
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
...
...
@@ -206,8 +219,8 @@ class OpenAIServingChat(OpenAIServing):
if
mm_data
is
not
None
:
engine_inputs
[
"multi_modal_data"
]
=
mm_data
is_tracing_enabled
=
(
await
self
.
async_
engine_client
.
is_tracing_enabled
())
is_tracing_enabled
=
(
await
self
.
engine_client
.
is_tracing_enabled
())
trace_headers
=
None
if
is_tracing_enabled
and
raw_request
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
...
@@ -215,7 +228,7 @@ class OpenAIServingChat(OpenAIServing):
and
contains_trace_headers
(
raw_request
.
headers
)):
log_tracing_disabled_warning
()
result_generator
=
self
.
async_
engine_client
.
generate
(
result_generator
=
self
.
engine_client
.
generate
(
engine_inputs
,
sampling_params
,
request_id
,
...
...
@@ -234,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
,
request_metadata
)
try
:
return
await
self
.
chat_completion_full_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
,
request_metadata
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
...
...
@@ -255,8 +270,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
model_name
=
self
.
se
rved
_model_name
s
[
0
]
model_name
=
self
.
ba
se_model_
paths
[
0
].
name
created_time
=
int
(
time
.
time
())
chunk_object_type
:
Final
=
"chat.completion.chunk"
first_iteration
=
True
...
...
@@ -293,6 +309,8 @@ class OpenAIServingChat(OpenAIServing):
async
for
res
in
result_generator
:
if
res
.
prompt_token_ids
is
not
None
:
num_prompt_tokens
=
len
(
res
.
prompt_token_ids
)
if
res
.
encoder_prompt_token_ids
is
not
None
:
num_prompt_tokens
+=
len
(
res
.
encoder_prompt_token_ids
)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
...
...
@@ -573,6 +591,13 @@ class OpenAIServingChat(OpenAIServing):
exclude_unset
=
True
,
exclude_none
=
True
))
yield
f
"data:
{
final_usage_data
}
\n\n
"
# report to FastAPI middleware aggregate usage across all choices
num_completion_tokens
=
sum
(
previous_num_tokens
)
request_metadata
.
final_usage_info
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_completion_tokens
,
total_tokens
=
num_prompt_tokens
+
num_completion_tokens
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
logger
.
error
(
"error in chat completion stream generator: %s"
,
e
)
...
...
@@ -588,9 +613,10 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
model_name
=
self
.
se
rved
_model_name
s
[
0
]
model_name
=
self
.
ba
se_model_
paths
[
0
].
name
created_time
=
int
(
time
.
time
())
final_res
:
Optional
[
RequestOutput
]
=
None
...
...
@@ -707,6 +733,9 @@ class OpenAIServingChat(OpenAIServing):
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
request_metadata
.
final_usage_info
=
usage
response
=
ChatCompletionResponse
(
id
=
request_id
,
created
=
created_time
,
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment