Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1876 additions
and
1121 deletions
+1876
-1121
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+109
-15
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+39
-8
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+54
-31
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+180
-142
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+75
-0
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+507
-0
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+391
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+5
-4
vllm/engine/protocol.py
vllm/engine/protocol.py
+4
-4
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+0
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+22
-6
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+20
-19
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+200
-38
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+179
-139
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+30
-4
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+5
-0
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+0
-451
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+0
-237
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+7
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+49
-20
No files found.
vllm/distributed/parallel_state.py
View file @
539aa992
...
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
...
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
"""
"""
import
contextlib
import
contextlib
import
pickle
import
pickle
import
weakref
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
multiprocessing
import
shared_memory
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -34,6 +35,8 @@ from torch.distributed import Backend, ProcessGroup
...
@@ -34,6 +35,8 @@ from torch.distributed import Backend, ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
supports_custom_op
@
dataclass
@
dataclass
...
@@ -69,6 +72,59 @@ def _split_tensor_dict(
...
@@ -69,6 +72,59 @@ def _split_tensor_dict(
return
metadata_list
,
tensor_list
return
metadata_list
,
tensor_list
_group_name_counter
:
Dict
[
str
,
int
]
=
{}
def
_get_unique_name
(
name
:
str
)
->
str
:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if
name
not
in
_group_name_counter
:
_group_name_counter
[
name
]
=
0
newname
=
f
"
{
name
}
:
{
_group_name_counter
[
name
]
}
"
_group_name_counter
[
name
]
+=
1
return
newname
_groups
:
Dict
[
str
,
Callable
[[],
"GroupCoordinator"
]]
=
{}
def
_register_group
(
group
:
"GroupCoordinator"
)
->
None
:
# looks like Python 3.8 does not understand `ReferenceType`
_groups
[
group
.
unique_name
]
=
weakref
.
ref
(
group
)
# type: ignore
if
supports_custom_op
():
@
torch
.
library
.
custom_op
(
"vllm::inplace_all_reduce"
,
mutates_args
=
[
"tensor"
])
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce
(
tensor
)
@
inplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
return
@
torch
.
library
.
custom_op
(
"vllm::outplace_all_reduce"
,
mutates_args
=
[])
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce
(
tensor
)
@
outplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
tensor
)
class
GroupCoordinator
:
class
GroupCoordinator
:
"""
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup wrapper for a group of processes.
...
@@ -111,7 +167,11 @@ class GroupCoordinator:
...
@@ -111,7 +167,11 @@ class GroupCoordinator:
use_custom_allreduce
:
bool
,
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_tpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
):
):
group_name
=
group_name
or
"anonymous"
self
.
unique_name
=
_get_unique_name
(
group_name
)
_register_group
(
self
)
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
local_rank
=
local_rank
self
.
local_rank
=
local_rank
...
@@ -134,7 +194,7 @@ class GroupCoordinator:
...
@@ -134,7 +194,7 @@ class GroupCoordinator:
assert
self
.
cpu_group
is
not
None
assert
self
.
cpu_group
is
not
None
assert
self
.
device_group
is
not
None
assert
self
.
device_group
is
not
None
if
torch
.
cuda
.
is_availabl
e
():
if
current_platform
.
is_cuda_alik
e
():
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
else
:
else
:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
device
=
torch
.
device
(
"cpu"
)
...
@@ -149,28 +209,24 @@ class GroupCoordinator:
...
@@ -149,28 +209,24 @@ class GroupCoordinator:
from
vllm.distributed.device_communicators.pynccl
import
(
from
vllm.distributed.device_communicators.pynccl
import
(
PyNcclCommunicator
)
PyNcclCommunicator
)
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
=
None
if
use_pynccl
and
self
.
world_size
>
1
:
if
use_pynccl
and
self
.
world_size
>
1
:
self
.
pynccl_comm
=
PyNcclCommunicator
(
self
.
pynccl_comm
=
PyNcclCommunicator
(
group
=
self
.
cpu_group
,
group
=
self
.
cpu_group
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
else
:
self
.
pynccl_comm
=
None
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
=
None
if
use_custom_allreduce
and
self
.
world_size
>
1
:
if
use_custom_allreduce
and
self
.
world_size
>
1
:
# Initialize a custom fast all-reduce implementation.
# Initialize a custom fast all-reduce implementation.
self
.
ca_comm
=
CustomAllreduce
(
self
.
ca_comm
=
CustomAllreduce
(
group
=
self
.
cpu_group
,
group
=
self
.
cpu_group
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
else
:
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.tpu_communicator
import
(
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
=
None
if
use_tpu_communicator
and
self
.
world_size
>
1
:
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
...
@@ -264,16 +320,49 @@ class GroupCoordinator:
...
@@ -264,16 +320,49 @@ class GroupCoordinator:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
if
not
supports_custom_op
():
return
self
.
_all_reduce
(
input_
)
if
self
.
tpu_communicator
is
not
None
and
\
not
self
.
tpu_communicator
.
disabled
:
# TPU handles Dynamo with its own logic.
return
self
.
_all_reduce
(
input_
)
if
self
.
ca_comm
is
not
None
and
self
.
ca_comm
.
should_custom_ar
(
input_
):
return
torch
.
ops
.
vllm
.
outplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
else
:
torch
.
ops
.
vllm
.
inplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
return
input_
def
_all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
Always assume this function modifies its input, but use the return
value as the output.
value as the output.
"""
"""
ca_comm
=
self
.
ca_comm
ca_comm
=
self
.
ca_comm
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
# For TPUs, use TPU communicator.
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
...
@@ -758,6 +847,7 @@ def init_world_group(ranks: List[int], local_rank: int,
...
@@ -758,6 +847,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl
=
False
,
use_pynccl
=
False
,
use_custom_allreduce
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
use_tpu_communicator
=
False
,
group_name
=
"world"
,
)
)
...
@@ -767,6 +857,7 @@ def init_model_parallel_group(
...
@@ -767,6 +857,7 @@ def init_model_parallel_group(
backend
:
str
,
backend
:
str
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
)
->
GroupCoordinator
:
)
->
GroupCoordinator
:
if
use_custom_allreduce
is
None
:
if
use_custom_allreduce
is
None
:
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
...
@@ -778,6 +869,7 @@ def init_model_parallel_group(
...
@@ -778,6 +869,7 @@ def init_model_parallel_group(
use_custom_allreduce
=
use_custom_allreduce
,
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_tpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
group_name
=
group_name
,
)
)
...
@@ -931,7 +1023,8 @@ def initialize_model_parallel(
...
@@ -931,7 +1023,8 @@ def initialize_model_parallel(
_TP
=
init_model_parallel_group
(
group_ranks
,
_TP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
get_world_group
().
local_rank
,
backend
,
backend
,
use_message_queue_broadcaster
=
True
)
use_message_queue_broadcaster
=
True
,
group_name
=
"tp"
)
# Build the pipeline model-parallel groups.
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups
:
int
=
(
world_size
//
num_pipeline_model_parallel_groups
:
int
=
(
world_size
//
...
@@ -947,7 +1040,8 @@ def initialize_model_parallel(
...
@@ -947,7 +1040,8 @@ def initialize_model_parallel(
_PP
=
init_model_parallel_group
(
group_ranks
,
_PP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
get_world_group
().
local_rank
,
backend
,
backend
,
use_custom_allreduce
=
False
)
use_custom_allreduce
=
False
,
group_name
=
"pp"
)
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
...
...
vllm/engine/arg_utils.py
View file @
539aa992
...
@@ -44,22 +44,36 @@ def nullable_str(val: str):
...
@@ -44,22 +44,36 @@ def nullable_str(val: str):
def
nullable_kvs
(
val
:
str
)
->
Optional
[
Mapping
[
str
,
int
]]:
def
nullable_kvs
(
val
:
str
)
->
Optional
[
Mapping
[
str
,
int
]]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.
Args:
val: String value to be parsed.
Returns:
Dictionary with parsed values.
"""
if
len
(
val
)
==
0
:
if
len
(
val
)
==
0
:
return
None
return
None
out_dict
:
Dict
[
str
,
int
]
=
{}
out_dict
:
Dict
[
str
,
int
]
=
{}
for
item
in
val
.
split
(
","
):
for
item
in
val
.
split
(
","
):
try
:
kv_parts
=
[
part
.
lower
().
strip
()
for
part
in
item
.
split
(
"="
)]
key
,
value
=
item
.
split
(
"="
)
if
len
(
kv_parts
)
!=
2
:
except
TypeError
as
exc
:
raise
argparse
.
Argument
TypeError
(
msg
=
"Each item should be in the form KEY=VALUE"
"Each item should be in the form KEY=VALUE"
)
raise
ValueError
(
msg
)
from
exc
key
,
value
=
kv_parts
try
:
try
:
out_dict
[
key
]
=
int
(
value
)
parsed_value
=
int
(
value
)
except
ValueError
as
exc
:
except
ValueError
as
exc
:
msg
=
f
"Failed to parse value of item
{
key
}
=
{
value
}
"
msg
=
f
"Failed to parse value of item
{
key
}
=
{
value
}
"
raise
ValueError
(
msg
)
from
exc
raise
argparse
.
ArgumentTypeError
(
msg
)
from
exc
if
key
in
out_dict
and
out_dict
[
key
]
!=
parsed_value
:
raise
argparse
.
ArgumentTypeError
(
f
"Conflicting values specified for key:
{
key
}
"
)
out_dict
[
key
]
=
parsed_value
return
out_dict
return
out_dict
...
@@ -131,6 +145,7 @@ class EngineArgs:
...
@@ -131,6 +145,7 @@ class EngineArgs:
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
device
:
str
=
'auto'
num_scheduler_steps
:
int
=
1
num_scheduler_steps
:
int
=
1
multi_step_stream_outputs
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
num_lookahead_slots
:
int
=
0
...
@@ -161,6 +176,7 @@ class EngineArgs:
...
@@ -161,6 +176,7 @@ class EngineArgs:
collect_detailed_traces
:
Optional
[
str
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
...
@@ -458,7 +474,10 @@ class EngineArgs:
...
@@ -458,7 +474,10 @@ class EngineArgs:
default
=
EngineArgs
.
max_seq_len_to_capture
,
default
=
EngineArgs
.
max_seq_len_to_capture
,
help
=
'Maximum sequence length covered by CUDA '
help
=
'Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.'
)
'larger than this, we fall back to eager mode. '
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.'
)
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
action
=
'store_true'
,
action
=
'store_true'
,
default
=
EngineArgs
.
disable_custom_all_reduce
,
default
=
EngineArgs
.
disable_custom_all_reduce
,
...
@@ -496,6 +515,12 @@ class EngineArgs:
...
@@ -496,6 +515,12 @@ class EngineArgs:
'e.g.: `image=16,video=2` allows a maximum of 16 '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'
))
'each modality.'
))
parser
.
add_argument
(
'--mm-processor-kwargs'
,
default
=
None
,
type
=
json
.
loads
,
help
=
(
'Overrides for the multimodal input mapping/processing,'
'e.g., image processor. For example: {"num_crops": 4}.'
))
# LoRA related configs
# LoRA related configs
parser
.
add_argument
(
'--enable-lora'
,
parser
.
add_argument
(
'--enable-lora'
,
...
@@ -571,6 +596,10 @@ class EngineArgs:
...
@@ -571,6 +596,10 @@ class EngineArgs:
help
=
(
'Maximum number of forward steps per '
help
=
(
'Maximum number of forward steps per '
'scheduler call.'
))
'scheduler call.'
))
parser
.
add_argument
(
'--multi-step-stream-outputs'
,
action
=
'store_true'
,
help
=
'If True, then multi-step will stream outputs for every step'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--scheduler-delay-factor'
,
'--scheduler-delay-factor'
,
type
=
float
,
type
=
float
,
...
@@ -805,6 +834,7 @@ class EngineArgs:
...
@@ -805,6 +834,7 @@ class EngineArgs:
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
override_neuron_config
=
self
.
override_neuron_config
,
override_neuron_config
=
self
.
override_neuron_config
,
config_format
=
self
.
config_format
,
config_format
=
self
.
config_format
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
)
)
def
create_load_config
(
self
)
->
LoadConfig
:
def
create_load_config
(
self
)
->
LoadConfig
:
...
@@ -974,6 +1004,7 @@ class EngineArgs:
...
@@ -974,6 +1004,7 @@ class EngineArgs:
is_multimodal_model
=
model_config
.
is_multimodal_model
,
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
multi_step_stream_outputs
=
self
.
multi_step_stream_outputs
,
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
and
parallel_config
.
use_ray
),
and
parallel_config
.
use_ray
),
)
)
...
...
vllm/engine/async_llm_engine.py
View file @
539aa992
import
asyncio
import
asyncio
import
time
import
time
import
weakref
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
weak_bind
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
...
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
method yields the outputs from the :class:`LLMEngine` to the caller.
method yields the outputs from the :class:`LLMEngine` to the caller.
Args:
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
log_requests: Whether to log the requests.
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
will be automatically started in the generate call.
...
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
...
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
_engine_class
:
Type
[
_AsyncLLMEngine
]
=
_AsyncLLMEngine
_engine_class
:
Type
[
_AsyncLLMEngine
]
=
_AsyncLLMEngine
def
__init__
(
self
,
def
__init__
(
self
,
worker_use_ray
:
bool
,
*
args
,
*
args
,
log_requests
:
bool
=
True
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
self
.
worker_use_ray
=
worker_use_ray
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_engine_class
(
*
args
,
**
kwargs
)
self
.
engine
=
self
.
_engine_class
(
*
args
,
**
kwargs
)
# This ensures quick processing of request outputs
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# so the append to asyncio queues is not delayed,
# especially for multi-step.
# especially for multi-step.
#
self
.
use_process_request_outputs_callback
=
(
self
.
use_process_request_outputs_callback
=
True
self
.
engine
.
model_config
.
use_async_output_proc
)
if
self
.
use_process_request_outputs_callback
:
if
self
.
use_process_request_outputs_callback
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
engine
.
process_request_outputs_callback
=
\
self
.
process_request_outputs
weak_bind
(
self
.
process_request_outputs
)
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
# We need to keep a reference to unshielded
# We need to keep a reference to unshielded
...
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
...
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
# Lazy initialized fields
# Lazy initialized fields
self
.
_request_tracker
:
RequestTracker
self
.
_request_tracker
:
RequestTracker
def
__del__
(
self
):
if
rt
:
=
getattr
(
self
,
"request_tracker"
,
None
):
# Wake up engine loop so that it will exit cleanly
rt
.
new_requests_event
.
set
()
@
classmethod
@
classmethod
def
_get_executor_cls
(
def
_get_executor_cls
(
cls
,
engine_config
:
EngineConfig
)
->
Type
[
ExecutorAsyncBase
]:
cls
,
engine_config
:
EngineConfig
)
->
Type
[
ExecutorAsyncBase
]:
...
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
...
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
raise
TypeError
(
raise
TypeError
(
"distributed_executor_backend must be a subclass of "
"distributed_executor_backend must be a subclass of "
f
"ExecutorAsyncBase. Got
{
distributed_executor_backend
}
."
)
f
"ExecutorAsyncBase. Got
{
distributed_executor_backend
}
."
)
if
distributed_executor_backend
.
uses_ray
:
# type: ignore
initialize_ray_cluster
(
engine_config
.
parallel_config
)
executor_class
=
distributed_executor_backend
executor_class
=
distributed_executor_backend
elif
engine_config
.
device_config
.
device_type
==
"neuron"
:
elif
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
if
distributed_executor_backend
==
"ray"
:
if
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutorAsync
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutorAsync
executor_class
=
RayTPUExecutorAsync
executor_class
=
RayTPUExecutorAsync
else
:
else
:
...
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
...
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
from
vllm.executor.xpu_executor
import
XPUExecutorAsync
from
vllm.executor.xpu_executor
import
XPUExecutorAsync
executor_class
=
XPUExecutorAsync
executor_class
=
XPUExecutorAsync
elif
distributed_executor_backend
==
"ray"
:
elif
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutorAsync
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutorAsync
executor_class
=
RayXPUExecutorAsync
executor_class
=
RayXPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
elif
distributed_executor_backend
==
"mp"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.multiproc_xpu_executor
import
(
from
vllm.executor.multiproc_xpu_executor
import
(
MultiprocessingXPUExecutorAsync
)
MultiprocessingXPUExecutorAsync
)
executor_class
=
MultiprocessingXPUExecutorAsync
executor_class
=
MultiprocessingXPUExecutorAsync
...
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
...
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
raise
RuntimeError
(
raise
RuntimeError
(
"Not supported distributed execution model on XPU device."
)
"Not supported distributed execution model on XPU device."
)
elif
distributed_executor_backend
==
"ray"
:
elif
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
elif
distributed_executor_backend
==
"mp"
:
...
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
...
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
def
from_engine_args
(
def
from_engine_args
(
cls
,
cls
,
engine_args
:
AsyncEngineArgs
,
engine_args
:
AsyncEngineArgs
,
engine_config
:
Optional
[
EngineConfig
]
=
None
,
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLMEngine"
:
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
if
engine_config
is
None
:
engine_config
=
engine_args
.
create_engine_config
()
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
if
executor_class
.
uses_ray
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
# Create the async LLM engine.
# Create the async LLM engine.
engine
=
cls
(
engine
=
cls
(
executor_class
.
uses_ray
,
**
engine_config
.
to_dict
(),
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_requests
=
not
engine_args
.
disable_log_requests
,
...
@@ -599,9 +601,12 @@ class AsyncLLMEngine:
...
@@ -599,9 +601,12 @@ class AsyncLLMEngine:
return
self
.
_errored_with
is
not
None
return
self
.
_errored_with
is
not
None
@
property
@
property
def
limit_concurrency
(
self
)
->
Optional
[
int
]:
def
dead_error
(
self
)
->
BaseException
:
"""Maximum number of concurrently running requests."""
return
AsyncEngineDeadError
(
return
None
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError)."
)
def
set_errored
(
self
,
exc
:
Exception
)
->
None
:
def
set_errored
(
self
,
exc
:
Exception
)
->
None
:
self
.
_errored_with
=
exc
self
.
_errored_with
=
exc
...
@@ -628,7 +633,7 @@ class AsyncLLMEngine:
...
@@ -628,7 +633,7 @@ class AsyncLLMEngine:
self
.
_request_tracker
=
RequestTracker
()
self
.
_request_tracker
=
RequestTracker
()
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
).
create_task
(
self
.
run_engine_loop
())
).
create_task
(
self
.
run_engine_loop
(
weakref
.
ref
(
self
)
))
self
.
_background_loop_unshielded
.
add_done_callback
(
self
.
_background_loop_unshielded
.
add_done_callback
(
partial
(
_log_task_completion
,
error_callback
=
self
.
_error_callback
))
partial
(
_log_task_completion
,
error_callback
=
self
.
_error_callback
))
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
...
@@ -698,9 +703,16 @@ class AsyncLLMEngine:
...
@@ -698,9 +703,16 @@ class AsyncLLMEngine:
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
self
.
engine
.
abort_request
(
request_ids
)
self
.
engine
.
abort_request
(
request_ids
)
async
def
run_engine_loop
(
self
):
@
staticmethod
async
def
run_engine_loop
(
engine_ref
:
ReferenceType
):
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
engine
:
Optional
[
"AsyncLLMEngine"
]
=
engine_ref
()
if
not
engine
:
return
pipeline_parallel_size
=
\
pipeline_parallel_size
=
\
self
.
engine
.
parallel_config
.
pipeline_parallel_size
engine
.
engine
.
parallel_config
.
pipeline_parallel_size
has_requests_in_progress
=
[
False
]
*
pipeline_parallel_size
has_requests_in_progress
=
[
False
]
*
pipeline_parallel_size
while
True
:
while
True
:
if
not
any
(
has_requests_in_progress
):
if
not
any
(
has_requests_in_progress
):
...
@@ -711,11 +723,21 @@ class AsyncLLMEngine:
...
@@ -711,11 +723,21 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
# such as add/remove lora adapters.
await
self
.
engine
.
stop_remote_worker_execution_loop_async
()
await
engine
.
engine
.
stop_remote_worker_execution_loop_async
()
await
self
.
_request_tracker
.
wait_for_new_requests
()
request_tracker
=
engine
.
_request_tracker
# Allow engine to be garbage collected while
# waiting for new requests
del
engine
await
asyncio
.
sleep
(
0
)
if
engine_ref
()
is
None
:
return
await
request_tracker
.
wait_for_new_requests
()
engine
=
engine_ref
()
if
not
engine
:
return
logger
.
debug
(
"Got new requests!"
)
logger
.
debug
(
"Got new requests!"
)
requests_in_progress
=
[
requests_in_progress
=
[
asyncio
.
create_task
(
self
.
engine_step
(
ve
))
asyncio
.
create_task
(
engine
.
engine_step
(
ve
))
for
ve
in
range
(
pipeline_parallel_size
)
for
ve
in
range
(
pipeline_parallel_size
)
]
]
has_requests_in_progress
=
[
True
]
*
pipeline_parallel_size
has_requests_in_progress
=
[
True
]
*
pipeline_parallel_size
...
@@ -733,19 +755,20 @@ class AsyncLLMEngine:
...
@@ -733,19 +755,20 @@ class AsyncLLMEngine:
result
=
task
.
result
()
result
=
task
.
result
()
virtual_engine
=
requests_in_progress
.
index
(
task
)
virtual_engine
=
requests_in_progress
.
index
(
task
)
has_unfinished_requests
=
(
has_unfinished_requests
=
(
self
.
engine
.
has_unfinished_requests_for_virtual_engine
(
engine
.
engine
.
has_unfinished_requests_for_virtual_engine
(
virtual_engine
))
virtual_engine
))
if
result
or
has_unfinished_requests
:
if
result
or
has_unfinished_requests
:
requests_in_progress
[
virtual_engine
]
=
(
requests_in_progress
[
virtual_engine
]
=
(
asyncio
.
create_task
(
asyncio
.
create_task
(
self
.
engine_step
(
virtual_engine
)))
engine
.
engine_step
(
virtual_engine
)))
has_requests_in_progress
[
virtual_engine
]
=
True
has_requests_in_progress
[
virtual_engine
]
=
True
else
:
else
:
has_requests_in_progress
[
virtual_engine
]
=
False
has_requests_in_progress
[
virtual_engine
]
=
False
except
asyncio
.
TimeoutError
as
exc
:
except
asyncio
.
TimeoutError
as
exc
:
logger
.
error
(
logger
.
error
(
"Engine iteration timed out. This should never happen!"
)
"Engine iteration timed out. This should never happen!"
)
self
.
set_errored
(
exc
)
engine
.
set_errored
(
exc
)
raise
raise
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
...
@@ -806,7 +829,7 @@ class AsyncLLMEngine:
...
@@ -806,7 +829,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
for generation, if any.
Yields:
Yields:
...
@@ -1022,7 +1045,7 @@ class AsyncLLMEngine:
...
@@ -1022,7 +1045,7 @@ class AsyncLLMEngine:
async
def
start_profile
(
self
)
->
None
:
async
def
start_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# using type instead of isinstance to check to avoid capturing
# inherited classes
# inherited classes
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
# noqa: E721
self
.
engine
.
model_executor
.
start_profile
()
self
.
engine
.
model_executor
.
start_profile
()
else
:
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
...
@@ -1030,7 +1053,7 @@ class AsyncLLMEngine:
...
@@ -1030,7 +1053,7 @@ class AsyncLLMEngine:
async
def
stop_profile
(
self
)
->
None
:
async
def
stop_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# using type instead of isinstance to check to avoid capturing
# inherited classes
# inherited classes
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
if
type
(
self
.
engine
.
model_executor
)
==
GPUExecutorAsync
:
# noqa: E721
self
.
engine
.
model_executor
.
stop_profile
()
self
.
engine
.
model_executor
.
stop_profile
()
else
:
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop_profile"
)
self
.
engine
.
model_executor
.
_run_workers
(
"stop_profile"
)
vllm/engine/llm_engine.py
View file @
539aa992
import
functools
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
,
Device
from
vllm.utils
import
Counter
,
Device
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -95,7 +95,7 @@ class OutputData(NamedTuple):
...
@@ -95,7 +95,7 @@ class OutputData(NamedTuple):
class
SchedulerContext
:
class
SchedulerContext
:
def
__init__
(
self
):
def
__init__
(
self
,
multi_step_stream_outputs
:
bool
=
False
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
EmbeddingRequestOutput
]]
=
[]
...
@@ -103,6 +103,8 @@ class SchedulerContext:
...
@@ -103,6 +103,8 @@ class SchedulerContext:
List
[
SequenceGroupMetadata
]]
=
None
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
self
.
multi_step_stream_outputs
:
bool
=
multi_step_stream_outputs
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
...
@@ -144,7 +146,7 @@ class LLMEngine:
...
@@ -144,7 +146,7 @@ class LLMEngine:
decoding.
decoding.
executor_class: The model executor class for managing distributed
executor_class: The model executor class for managing distributed
execution.
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
prompt adapters.
log_stats: Whether to log statistics.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
usage_context: Specified entry point, used for usage info collection.
...
@@ -219,6 +221,7 @@ class LLMEngine:
...
@@ -219,6 +221,7 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
)
->
None
:
)
->
None
:
logger
.
info
(
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
"Initializing an LLM engine (v%s) with config: "
...
@@ -234,8 +237,9 @@ class LLMEngine:
...
@@ -234,8 +237,9 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"use_async_output_proc=%s)"
,
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)"
,
VLLM_VERSION
,
VLLM_VERSION
,
model_config
.
model
,
model_config
.
model
,
speculative_config
,
speculative_config
,
...
@@ -266,8 +270,11 @@ class LLMEngine:
...
@@ -266,8 +270,11 @@ class LLMEngine:
model_config
.
served_model_name
,
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
multi_step_stream_outputs
,
cache_config
.
enable_prefix_caching
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
model_config
.
use_async_output_proc
,
use_cached_outputs
,
model_config
.
mm_processor_kwargs
,
)
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
from
vllm.plugins
import
load_general_plugins
...
@@ -286,6 +293,7 @@ class LLMEngine:
...
@@ -286,6 +293,7 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
)
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
@@ -327,136 +335,134 @@ class LLMEngine:
...
@@ -327,136 +335,134 @@ class LLMEngine:
observability_config
=
self
.
observability_config
,
observability_config
=
self
.
observability_config
,
)
)
init_success
=
False
if
not
self
.
model_config
.
embedding_mode
:
try
:
self
.
_initialize_kv_caches
()
if
not
self
.
model_config
.
embedding_mode
:
self
.
_initialize_kv_caches
()
# If usage stat is enabled, collect relevant info.
if
is_usage_stats_enabled
():
# If usage stat is enabled, collect relevant info.
from
vllm.model_executor.model_loader
import
(
if
is_usage_stats_enabled
():
get_architecture_class_name
)
from
vllm.model_executor.model_loader
import
(
usage_message
.
report_usage
(
get_architecture_class_name
)
get_architecture_class_name
(
model_config
),
usage_message
.
report_usage
(
usage_context
,
get_architecture_class_name
(
model_config
),
extra_kvs
=
{
usage_context
,
# Common configuration
extra_kvs
=
{
"dtype"
:
# Common configuration
str
(
model_config
.
dtype
),
"dtype"
:
"tensor_parallel_size"
:
str
(
model_config
.
dtype
),
parallel_config
.
tensor_parallel_size
,
"tensor_parallel_size"
:
"block_size"
:
parallel_config
.
tensor_parallel_size
,
cache_config
.
block_size
,
"block_size"
:
"gpu_memory_utilization"
:
cache_config
.
block_size
,
cache_config
.
gpu_memory_utilization
,
"gpu_memory_utilization"
:
cache_config
.
gpu_memory_utilization
,
# Quantization
"quantization"
:
# Quantization
model_config
.
quantization
,
"quantization"
:
"kv_cache_dtype"
:
model_config
.
quantization
,
str
(
cache_config
.
cache_dtype
),
"kv_cache_dtype"
:
str
(
cache_config
.
cache_dtype
),
# Feature flags
"enable_lora"
:
# Feature flags
bool
(
lora_config
),
"enable_lora"
:
"enable_prompt_adapter"
:
bool
(
lora_config
),
bool
(
prompt_adapter_config
),
"enable_prompt_adapter"
:
"enable_prefix_caching"
:
bool
(
prompt_adapter_config
),
cache_config
.
enable_prefix_caching
,
"enable_prefix_caching"
:
"enforce_eager"
:
cache_config
.
enable_prefix_caching
,
model_config
.
enforce_eager
,
"enforce_eager"
:
"disable_custom_all_reduce"
:
model_config
.
enforce_eager
,
parallel_config
.
disable_custom_all_reduce
,
"disable_custom_all_reduce"
:
})
parallel_config
.
disable_custom_all_reduce
,
})
if
self
.
tokenizer
:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self
.
tokenizer
.
ping
()
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
if
self
.
tokenizer
:
SchedulerContext
()
# Ping the tokenizer to ensure liveness if it runs in a
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
# different process.
]
self
.
tokenizer
.
ping
()
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
SchedulerContext
(
multi_step_stream_outputs
=
self
.
scheduler_config
.
multi_step_stream_outputs
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
if
model_config
.
use_async_output_proc
:
process_model_outputs
=
weak_bind
(
self
.
_process_model_outputs
)
self
.
async_callbacks
=
[
self
.
async_callbacks
=
[
functools
.
partial
(
self
.
_
process_model_outputs
,
partial
(
process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
else
:
self
.
async_callbacks
=
[]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
:
Optional
[
Callable
]
=
None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self
.
scheduler
=
[
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
self
.
async_callbacks
[
v_id
]
if
model_config
.
use_async_output_proc
else
None
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
# Currently used by AsyncLLMEngine to ensure quick append
# Metric Logging.
# of request outputs to asyncio queues
if
self
.
log_stats
:
self
.
process_request_outputs_callback
=
None
if
stat_loggers
is
not
None
:
self
.
stat_loggers
=
stat_loggers
# Create the scheduler.
else
:
# NOTE: the cache_config here have been updated with the numbers of
# Lazy import for prometheus multiprocessing.
# GPU and CPU blocks, which are profiled in the distributed executor.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
self
.
scheduler
=
[
# before prometheus_client is imported.
Scheduler
(
# See https://prometheus.github.io/client_python/multiprocess/
scheduler_config
,
cache_config
,
lora_config
,
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
parallel_config
.
pipeline_parallel_size
,
PrometheusStatLogger
)
self
.
async_callbacks
[
v_id
]
if
model_config
.
use_async_output_proc
else
None
)
self
.
stat_loggers
=
{
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
"logging"
:
]
LoggingStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
),
# Metric Logging.
"prometheus"
:
if
self
.
log_stats
:
PrometheusStatLogger
(
if
stat_loggers
is
not
None
:
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
self
.
stat_loggers
=
stat_loggers
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
else
:
max_model_len
=
self
.
model_config
.
max_model_len
),
# Lazy import for prometheus multiprocessing.
}
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
self
.
stat_loggers
[
"prometheus"
].
info
(
"cache_config"
,
# before prometheus_client is imported.
self
.
cache_config
)
# See https://prometheus.github.io/client_python/multiprocess/
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
self
.
tracer
=
None
PrometheusStatLogger
)
if
self
.
observability_config
.
otlp_traces_endpoint
:
self
.
tracer
=
init_tracer
(
self
.
stat_loggers
=
{
"vllm.llm_engine"
,
"logging"
:
self
.
observability_config
.
otlp_traces_endpoint
)
LoggingStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
),
# Create sequence output processor, e.g. for beam search or
"prometheus"
:
# speculative decoding.
PrometheusStatLogger
(
self
.
output_processor
=
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
SequenceGroupOutputProcessor
.
create_output_processor
(
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
self
.
scheduler_config
,
max_model_len
=
self
.
model_config
.
max_model_len
),
self
.
detokenizer
,
}
self
.
scheduler
,
self
.
stat_loggers
[
"prometheus"
].
info
(
"cache_config"
,
self
.
seq_counter
,
self
.
cache_config
)
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
self
.
tracer
=
None
self
.
scheduler_config
.
max_model_len
,
if
self
.
observability_config
.
otlp_traces_endpoint
:
self
.
tracer
=
init_tracer
(
"vllm.llm_engine"
,
self
.
observability_config
.
otlp_traces_endpoint
)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self
.
output_processor
=
(
SequenceGroupOutputProcessor
.
create_output_processor
(
self
.
scheduler_config
,
self
.
detokenizer
,
self
.
scheduler
,
self
.
seq_counter
,
get_tokenizer_for_seq
,
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
),
self
.
scheduler_config
.
max_model_len
,
))
get_tokenizer_for_seq
,
),
))
init_success
=
True
finally
:
if
not
init_success
:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self
.
model_executor
.
shutdown
()
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -625,6 +631,7 @@ class LLMEngine:
...
@@ -625,6 +631,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
self
.
_validate_model_inputs
(
processed_inputs
)
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
# Create the sequences.
...
@@ -655,7 +662,8 @@ class LLMEngine:
...
@@ -655,7 +662,8 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
elif
isinstance
(
params
,
PoolingParams
):
elif
isinstance
(
params
,
PoolingParams
):
seq_group
=
self
.
_create_sequence_group_with_pooling
(
seq_group
=
self
.
_create_sequence_group_with_pooling
(
request_id
,
request_id
,
...
@@ -664,7 +672,8 @@ class LLMEngine:
...
@@ -664,7 +672,8 @@ class LLMEngine:
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Either SamplingParams or PoolingParams must be provided."
)
"Either SamplingParams or PoolingParams must be provided."
)
...
@@ -689,6 +698,7 @@ class LLMEngine:
...
@@ -689,6 +698,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
...
@@ -707,6 +717,8 @@ class LLMEngine:
...
@@ -707,6 +717,8 @@ class LLMEngine:
arrival_time: The arrival time of the request. If None, we use
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
the current monotonic time.
trace_headers: OpenTelemetry trace headers.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Details:
Details:
- Set arrival_time to the current time if it is None.
- Set arrival_time to the current time if it is None.
...
@@ -735,6 +747,11 @@ class LLMEngine:
...
@@ -735,6 +747,11 @@ class LLMEngine:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
if
priority
>
0
and
not
self
.
scheduler_config
.
policy
==
"priority"
:
raise
ValueError
(
f
"Got priority
{
priority
}
but "
"Priority scheduling is not enabled."
)
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
...
@@ -754,6 +771,7 @@ class LLMEngine:
...
@@ -754,6 +771,7 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
priority
,
)
)
def
_create_sequence_group_with_sampling
(
def
_create_sequence_group_with_sampling
(
...
@@ -766,6 +784,7 @@ class LLMEngine:
...
@@ -766,6 +784,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
)
->
SequenceGroup
:
"""Creates a SequenceGroup with SamplingParams."""
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs
=
self
.
get_model_config
().
max_logprobs
max_logprobs
=
self
.
get_model_config
().
max_logprobs
...
@@ -792,7 +811,8 @@ class LLMEngine:
...
@@ -792,7 +811,8 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
return
seq_group
return
seq_group
...
@@ -805,6 +825,7 @@ class LLMEngine:
...
@@ -805,6 +825,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
encoder_seq
:
Optional
[
Sequence
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
)
->
SequenceGroup
:
"""Creates a SequenceGroup with PoolingParams."""
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
# Defensive copy of PoolingParams, which are used by the pooler
...
@@ -817,7 +838,8 @@ class LLMEngine:
...
@@ -817,7 +838,8 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
pooling_params
=
pooling_params
,
pooling_params
=
pooling_params
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
return
seq_group
return
seq_group
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
...
@@ -877,8 +899,8 @@ class LLMEngine:
...
@@ -877,8 +899,8 @@ class LLMEngine:
"""
"""
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
@
staticmethod
def
_process_sequence_group_outputs
(
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
EmbeddingSequenceGroupOutput
],
outputs
:
List
[
EmbeddingSequenceGroupOutput
],
)
->
None
:
)
->
None
:
...
@@ -1001,7 +1023,8 @@ class LLMEngine:
...
@@ -1001,7 +1023,8 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
...
@@ -1022,8 +1045,8 @@ class LLMEngine:
...
@@ -1022,8 +1045,8 @@ class LLMEngine:
for
scheduler
in
self
.
scheduler
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
scheduler
.
free_finished_seq_groups
()
# For multi-step, do
no
t create outputs each iteration
# For multi-step
without streaming
, do
n'
t create outputs each iteration
if
not
is_last_step
:
if
not
is_last_step
and
not
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
# Immediately process request outputs here (if callback is given)
if
(
finished_now
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
and
self
.
process_request_outputs_callback
is
not
None
):
...
@@ -1040,17 +1063,27 @@ class LLMEngine:
...
@@ -1040,17 +1063,27 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
# For multi-step with streaming, create outputs each iteration
if
not
is_last_step
and
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
self
.
process_request_outputs_callback
is
not
None
:
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
...
@@ -1292,6 +1325,7 @@ class LLMEngine:
...
@@ -1292,6 +1325,7 @@ class LLMEngine:
# torch.distributed ops which may otherwise timeout, and unblocks
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
# queued control plane messages, such as add/remove lora adapters.
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
ctx
.
request_outputs
return
ctx
.
request_outputs
...
@@ -1608,7 +1642,7 @@ class LLMEngine:
...
@@ -1608,7 +1642,7 @@ class LLMEngine:
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
# inherited classes (MultiprocessingGPUExecutor)
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
# noqa: E721
self
.
model_executor
.
start_profile
()
self
.
model_executor
.
start_profile
()
else
:
else
:
self
.
model_executor
.
_run_workers
(
"start_profile"
)
self
.
model_executor
.
_run_workers
(
"start_profile"
)
...
@@ -1616,7 +1650,7 @@ class LLMEngine:
...
@@ -1616,7 +1650,7 @@ class LLMEngine:
def
stop_profile
(
self
)
->
None
:
def
stop_profile
(
self
)
->
None
:
# using type instead of isinstance to check to avoid capturing
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
# inherited classes (MultiprocessingGPUExecutor)
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
if
type
(
self
.
model_executor
)
==
GPUExecutor
:
# noqa: E721
self
.
model_executor
.
stop_profile
()
self
.
model_executor
.
stop_profile
()
else
:
else
:
self
.
model_executor
.
_run_workers
(
"stop_profile"
)
self
.
model_executor
.
_run_workers
(
"stop_profile"
)
...
@@ -1700,7 +1734,11 @@ class LLMEngine:
...
@@ -1700,7 +1734,11 @@ class LLMEngine:
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
LLMInputs
,
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]):
EncoderDecoderLLMInputs
]):
if
self
.
is_encoder_decoder_model
():
if
self
.
model_config
.
is_multimodal_model
:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids
=
inputs
.
get
(
"prompt_token_ids"
)
elif
self
.
is_encoder_decoder_model
():
prompt_ids
=
inputs
.
get
(
"encoder_prompt_token_ids"
)
prompt_ids
=
inputs
.
get
(
"encoder_prompt_token_ids"
)
else
:
else
:
prompt_ids
=
inputs
.
get
(
"prompt_token_ids"
)
prompt_ids
=
inputs
.
get
(
"prompt_token_ids"
)
...
...
vllm/en
trypoints/openai/rpc
/__init__.py
→
vllm/en
gine/multiprocessing
/__init__.py
View file @
539aa992
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
vllm
import
PoolingParams
from
vllm.inputs
import
PromptInputs
from
vllm.inputs
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
IPC_INPUT_EXT
=
"_input_socket"
VLLM_RPC_SOCKET_LIMIT_CUTOFF
=
2000
IPC_OUTPUT_EXT
=
"_output_socket"
IPC_HEALTH_EXT
=
"_health_socket"
IPC_DATA_EXT
=
"_data_socket"
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM
=
0
class
MQEngineDeadError
(
RuntimeError
):
pass
@
dataclass
@
dataclass
class
RPC
Generate
Request
:
class
RPC
Process
Request
:
inputs
:
PromptInputs
inputs
:
PromptInputs
s
ampling
_p
arams
:
Samp
lingParams
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
request_id
:
str
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
@
dataclass
class
RPCError
:
request_id
:
Optional
[
str
]
is_engine_errored
:
bool
exception
:
BaseException
@
dataclass
@
dataclass
class
RPCAbortRequest
:
class
RPCAbortRequest
:
request_id
:
str
request_id
:
str
class
RPC
Utility
Request
(
Enum
):
class
RPC
Startup
Request
(
Enum
):
IS_SERVER_READY
=
1
IS_SERVER_READY
=
1
GET_MODEL_CONFIG
=
2
GET_DECODING_CONFIG
=
3
GET_PARALLEL_CONFIG
=
4
@
dataclass
GET_SCHEDULER_CONFIG
=
5
class
RPCStartupResponse
:
GET_LORA_CONFIG
=
6
tracing_enabled
:
bool
DO_LOG_STATS
=
7
IS_SERVER_HEALTHY
=
8
IS_TRACING_ENABLED
=
9
class
RPCUProfileRequest
(
Enum
):
START_PROFILE
=
10
START_PROFILE
=
1
STOP_PROFILE
=
11
STOP_PROFILE
=
2
RPC_REQUEST_TYPE
=
Union
[
RPCGenerateRequest
,
RPCAbortRequest
,
RPC_REQUEST_T
=
Union
[
RPCProcessRequest
,
RPCAbortRequest
,
RPCStartupRequest
,
RPCUtilityRequest
]
RPCUProfileRequest
]
REQUEST_OUTPUTS_T
=
Union
[
List
[
RequestOutput
],
RPCError
]
def
ENGINE_DEAD_ERROR
(
error
:
Optional
[
BaseException
]
=
None
)
->
MQEngineDeadError
:
if
error
is
None
:
return
MQEngineDeadError
(
"Engine loop is not running. Inspect the stacktrace to "
"find the original error"
)
return
MQEngineDeadError
(
"Engine loop is not running. Inspect the stacktrace to "
f
"find the original error:
{
repr
(
error
)
}
."
)
vllm/engine/multiprocessing/client.py
0 → 100644
View file @
539aa992
import
asyncio
import
copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Union
)
import
cloudpickle
import
zmq
import
zmq.asyncio
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
PoolingParams
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
logger
=
init_logger
(
__name__
)
class
MQClientClosedError
(
Exception
):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class
MQLLMEngineClient
:
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
MQLLMEngine and MQLLMEngineClient are intended to run in separate
processes communicating via zeromq ipc sockets.
The entrypoint to MQLLMEngineClient is through the generate()
method. On generate() MQLLMEngine does three things:
- Creates an asyncio output queue
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
- Pulls RequestOutputs from its queue and yields them
MQLLMEngine runs two background loops:
- output_loop: the output loop pulls List[RequestOutput]
from the MQLLMEngine via zmq (each list is the output
of one engine_step in the LLMEngine). It then parses
the list and pushes individual request_outputs into
the corresponding output_queue such that they can be
consumed by the .generate() method.
- health_loop: the health loop queries the health socket
every N seconds, confirming the engine is healthy
"""
def
__init__
(
self
,
ipc_path
:
str
,
engine_config
:
EngineConfig
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
# Get the configs.
self
.
model_config
=
engine_config
.
model_config
self
.
decoding_config
=
engine_config
.
decoding_config
# Create the tokenizer group.
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
parallel_config
=
engine_config
.
parallel_config
,
enable_lora
=
bool
(
engine_config
.
lora_config
),
)
# Send RPCGenerateRequest to the MQLLMEngine.
self
.
input_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
input_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_INPUT_EXT
}
"
)
# Receive streams of RequestOutput from the MQLLMEngine.
self
.
output_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
output_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
# IPC path for acking heartbeats.
self
.
heartbeat_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
heartbeat_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
# IPC path for the data socket.
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
# Stream for each individual request.
self
.
output_queues
:
Dict
[
str
,
asyncio
.
Queue
]
=
{}
self
.
output_loop
=
asyncio
.
create_task
(
self
.
run_output_handler_loop
())
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self
.
health_loop
:
Optional
[
asyncio
.
Task
]
=
None
@
staticmethod
def
is_unsupported_config
(
engine_args
:
AsyncEngineArgs
):
# Pipeline parallel not yet supported
return
engine_args
.
pipeline_parallel_size
>
1
@
contextmanager
def
get_data_socket
(
self
)
->
Iterator
[
Socket
]:
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
try
:
socket
.
connect
(
self
.
data_ipc_path
)
yield
socket
finally
:
socket
.
close
(
linger
=
0
)
async
def
run_heartbeat_loop
(
self
,
timeout
:
int
):
"""Background loop that continually listens to the RPCServer for
heartbeats.
"""
try
:
while
True
:
if
await
self
.
heartbeat_socket
.
poll
(
timeout
=
timeout
)
==
0
:
# No heartbeat was received. Set error and exit the loop
self
.
_set_errored
(
TimeoutError
(
"No heartbeat received "
"from MQLLMEngine"
))
logger
.
debug
(
"Shutting down MQLLMEngineClient check "
"health loop due to timeout"
)
break
else
:
# Heartbeat received- check the message
await
self
.
_check_success
(
error_message
=
"Heartbeat failed."
,
socket
=
self
.
heartbeat_socket
)
logger
.
debug
(
"Heartbeat successful."
)
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient check health loop."
)
except
Exception
as
e
:
self
.
_set_errored
(
e
)
async
def
run_output_handler_loop
(
self
):
"""Get RequestOutputs from Engine and stream to Request Queues"""
try
:
while
True
:
# Poll, checking for ENGINE_DEAD
while
await
self
.
output_socket
.
poll
(
timeout
=
VLLM_RPC_TIMEOUT
)
==
0
:
logger
.
debug
(
"Waiting for output from MQLLMEngine."
)
# If errored, alert all running requests.
if
self
.
errored
:
for
queue_j
in
tuple
(
self
.
output_queues
.
values
()):
queue_j
.
put_nowait
(
ENGINE_DEAD_ERROR
(
self
.
_errored_with
))
return
message
:
Frame
=
await
self
.
output_socket
.
recv
(
copy
=
False
)
request_outputs
=
pickle
.
loads
(
message
.
buffer
)
is_error
=
isinstance
(
request_outputs
,
(
BaseException
,
RPCError
))
if
is_error
:
if
isinstance
(
request_outputs
,
RPCError
):
rpc_error
:
RPCError
=
request_outputs
request_id
=
rpc_error
.
request_id
exception
=
rpc_error
.
exception
is_engine_errored
=
rpc_error
.
is_engine_errored
else
:
# MPLLMEngine should always return an RPCError to
# the output_socket when an issue arises.
# If we are here, we are in a bad state and
# should shut down the server.
error
:
BaseException
=
request_outputs
logger
.
error
(
"Received Exception %s rather than RPCError from "
"MPLLMEngine. This should never happen."
,
error
)
request_id
=
None
exception
=
error
is_engine_errored
=
True
# Set to error state only on engine critical error
# (and record only the first one)
if
is_engine_errored
and
not
self
.
_errored_with
:
self
.
_errored_with
=
exception
if
request_id
is
None
:
for
queue_i
in
tuple
(
self
.
output_queues
.
values
()):
queue_i
.
put_nowait
(
exception
)
else
:
queue
=
self
.
output_queues
.
get
(
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
exception
)
else
:
# Put each output into the appropriate steam.
for
request_output
in
request_outputs
:
queue
=
self
.
output_queues
.
get
(
request_output
.
request_id
)
if
queue
is
not
None
:
queue
.
put_nowait
(
request_output
)
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient output handler."
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
with
self
.
get_data_socket
()
as
socket
:
# Wait until server is ready.
response
=
await
self
.
_wait_for_server_rpc
(
socket
)
self
.
tracing_flag
=
response
.
tracing_enabled
# Start health_loop.
self
.
health_loop
=
asyncio
.
create_task
(
self
.
run_heartbeat_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
# Close all sockets and terminate the context.
self
.
context
.
destroy
(
linger
=
0
)
# Cancel background tasks.
if
self
.
health_loop
is
not
None
:
self
.
health_loop
.
cancel
()
self
.
output_loop
.
cancel
()
def
_set_errored
(
self
,
e
:
BaseException
):
logger
.
exception
(
repr
(
e
))
if
self
.
_errored_with
is
None
:
self
.
_errored_with
=
e
@
staticmethod
async
def
_send_get_data_rpc_request
(
request
:
RPCStartupRequest
,
expected_type
:
Any
,
error_message
:
str
,
socket
:
Socket
)
->
Any
:
"""Send an RPC request that is expecting data back."""
# Ping RPCServer with a request.
await
socket
.
send_multipart
((
pickle
.
dumps
(
request
),
),
copy
=
False
)
# Make sure the server responds in time.
if
await
socket
.
poll
(
timeout
=
VLLM_RPC_TIMEOUT
)
==
0
:
raise
TimeoutError
(
"RPCServer didn't reply within "
f
"
{
VLLM_RPC_TIMEOUT
}
ms"
)
# Await the data from the Server.
frame
=
await
socket
.
recv
(
copy
=
False
)
data
=
pickle
.
loads
(
frame
.
buffer
)
if
isinstance
(
data
,
BaseException
):
raise
data
elif
not
isinstance
(
data
,
expected_type
):
raise
ValueError
(
error_message
)
return
data
@
staticmethod
async
def
_send_one_way_rpc_request
(
request
:
RPC_REQUEST_T
,
socket
:
Socket
):
"""Send one-way RPC request to trigger an action."""
if
socket
.
closed
:
raise
MQClientClosedError
()
await
socket
.
send_multipart
((
pickle
.
dumps
(
request
),
))
async
def
_await_ack
(
self
,
error_message
:
str
,
socket
:
Socket
):
"""Await acknowledgement that a request succeeded."""
if
socket
.
closed
:
raise
MQClientClosedError
()
if
await
socket
.
poll
(
timeout
=
VLLM_RPC_TIMEOUT
)
==
0
:
raise
TimeoutError
(
"MQLLMEngine didn't reply within "
f
"
{
VLLM_RPC_TIMEOUT
}
ms"
)
await
self
.
_check_success
(
error_message
,
socket
)
@
staticmethod
async
def
_check_success
(
error_message
:
str
,
socket
:
Socket
):
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
if
socket
.
closed
:
raise
MQClientClosedError
()
frame
=
await
socket
.
recv
(
copy
=
False
)
response
=
pickle
.
loads
(
frame
.
buffer
)
# Raise error if unsuccessful
if
isinstance
(
response
,
BaseException
):
raise
response
elif
(
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
):
raise
ValueError
(
error_message
)
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
return
self
.
decoding_config
async
def
get_model_config
(
self
)
->
ModelConfig
:
return
self
.
model_config
async
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracing_flag
async
def
_wait_for_server_rpc
(
self
,
socket
:
Socket
)
->
RPCStartupResponse
:
"""Wait for the RPCServer to start up."""
return
await
self
.
_send_get_data_rpc_request
(
request
=
RPCStartupRequest
.
IS_SERVER_READY
,
expected_type
=
RPCStartupResponse
,
error_message
=
"Unable to start RPC Server"
,
socket
=
socket
)
async
def
abort
(
self
,
request_id
:
str
):
"""Send an ABORT_REQUEST signal to the RPC Server"""
with
suppress
(
MQClientClosedError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
socket
=
self
.
input_socket
)
async
def
do_log_stats
(
self
):
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
pass
async
def
check_health
(
self
):
"""
The check health loop probes the health status of the
Engine's health every N seconds and sets _errored_with
if the engine is unhealthy.
"""
if
self
.
_errored_with
is
not
None
:
raise
self
.
_errored_with
@
property
def
is_running
(
self
)
->
bool
:
return
not
self
.
errored
@
property
def
is_stopped
(
self
)
->
bool
:
return
self
.
errored
@
property
def
errored
(
self
)
->
bool
:
return
self
.
_errored_with
is
not
None
@
property
def
dead_error
(
self
)
->
BaseException
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
def
encode
(
self
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
async
def
_process_request
(
self
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
if
self
.
_errored_with
is
not
None
:
raise
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
# 1) Create output queue for this requests.
queue
:
asyncio
.
Queue
[
Union
[
RequestOutput
,
BaseException
]]
=
asyncio
.
Queue
()
self
.
output_queues
[
request_id
]
=
queue
try
:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if
isinstance
(
params
,
SamplingParams
)
and
params
.
logits_processors
:
# Defensive shallow copy
params
=
copy
.
copy
(
params
)
logits_processors
=
params
.
logits_processors
params
.
logits_processors
=
None
lp_bytes
=
cloudpickle
.
dumps
(
logits_processors
)
else
:
lp_bytes
=
None
request_bytes
=
pickle
.
dumps
(
RPCProcessRequest
(
inputs
=
inputs
,
params
=
params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts
=
(
request_bytes
,
lp_bytes
)
if
lp_bytes
else
(
request_bytes
,
)
await
self
.
input_socket
.
send_multipart
(
parts
,
copy
=
False
)
# 4) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
# queue after pulling them from the zmq socket.
finished
=
False
try
:
while
not
finished
:
request_output
=
await
queue
.
get
()
if
isinstance
(
request_output
,
BaseException
):
raise
request_output
finished
=
request_output
.
finished
yield
request_output
finally
:
# Request was canceled by the client.
if
not
finished
and
not
self
.
errored
:
await
self
.
abort
(
request_id
)
finally
:
self
.
output_queues
.
pop
(
request_id
)
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
START_PROFILE
,
socket
=
self
.
input_socket
)
async
def
stop_profile
(
self
)
->
None
:
"""Stop profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
vllm/engine/multiprocessing/engine.py
0 → 100644
View file @
539aa992
import
pickle
import
signal
import
threading
import
time
from
contextlib
import
contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
import
cloudpickle
import
zmq
from
vllm
import
AsyncEngineArgs
,
LLMEngine
,
SamplingParams
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
logger
=
init_logger
(
__name__
)
POLLING_TIMEOUT_MS
=
10000
HEALTHY_RESPONSE
=
(
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
)
class
MQLLMEngine
:
"""A multiprocessing wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The :class:`LLMEngine` generate or encode process is kicked off when a new
RPCProcessRequest is received by the input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
:class:`LLMEngine.step()`, and sends the RequestOutputs back over
the output_socket.
If use_async_sockets is set, the logic associated with reading new
requests from the socket and sending data to the socket is passed
as a callback to the llm_engine, which calls the logic asynchronously
such that the IPC can be overlapped with the GPU.
Args:
ipc_path: Base path for zeromq interprocess messaging
use_async_sockets: Whether to make send/recv async with GPU
log_requests: Whether to log the requests.
*args: Arguments for :class:`LLMEngine`.
**kwargs: Arguments for :class:`LLMEngine`.
"""
def
__init__
(
self
,
ipc_path
:
str
,
use_async_sockets
:
bool
,
*
args
,
log_requests
:
bool
=
True
,
**
kwargs
)
->
None
:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
use_cached_outputs
=
True
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
,
use_cached_outputs
=
use_cached_outputs
)
self
.
log_requests
=
log_requests
self
.
use_async_sockets
=
use_async_sockets
if
self
.
use_async_sockets
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
_async_socket_engine_callback
self
.
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
# Receive input from the client.
self
.
input_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PULL
)
self
.
input_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_INPUT_EXT
}
"
)
# Send output stream back to client.
self
.
output_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
output_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
# Send heartbeats back to client.
self
.
heartbeat_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
heartbeat_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
# IPC path for the data socket.
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
# Error state.
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
# Heartbeat thread
self
.
heartbeat_thread
=
threading
.
Thread
(
target
=
self
.
_heartbeat_loop
,
daemon
=
True
)
self
.
_heartbeat_stop_event
=
threading
.
Event
()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
heartbeat_interval_seconds
=
VLLM_RPC_TIMEOUT
/
5000.0
self
.
_last_alive_time
=
time
.
time
()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
last_alive_threshold
=
VLLM_RPC_TIMEOUT
*
3.0
/
1000.0
@
property
def
dead_error
(
self
)
->
BaseException
:
if
self
.
_errored_with
is
not
None
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
else
:
return
ENGINE_DEAD_ERROR
()
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
ipc_path
:
str
):
"""Creates an MQLLMEngine from the engine arguments."""
engine_config
=
engine_args
.
create_engine_config
()
executor_class
=
LLMEngine
.
_get_executor_cls
(
engine_config
)
return
cls
(
ipc_path
=
ipc_path
,
use_async_sockets
=
engine_config
.
model_config
.
use_async_output_proc
,
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
)
def
start
(
self
):
try
:
try
:
logger
.
debug
(
"Starting Startup Loop."
)
self
.
run_startup_loop
()
logger
.
debug
(
"Starting heartbeat thread"
)
self
.
heartbeat_thread
.
start
()
logger
.
debug
(
"Starting Engine Loop."
)
self
.
run_engine_loop
()
except
Exception
as
e
:
logger
.
exception
(
repr
(
e
))
except
KeyboardInterrupt
:
logger
.
debug
(
"Shutting down MQLLMEngine."
)
finally
:
logger
.
debug
(
"MQLLMEngine is shut down."
)
self
.
cleanup
()
def
cleanup
(
self
):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self
.
_heartbeat_stop_event
.
set
()
self
.
ctx
.
destroy
(
linger
=
0
)
del
self
.
engine
@
contextmanager
def
make_data_socket
(
self
)
->
Iterator
[
zmq
.
Socket
]:
# type: ignore[name-defined]
socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
ROUTER
)
try
:
socket
.
bind
(
self
.
data_ipc_path
)
yield
socket
finally
:
socket
.
close
(
linger
=
0
)
def
run_startup_loop
(
self
)
->
None
:
"""Startup loop for sending data from Engine -> Client."""
with
self
.
make_data_socket
()
as
socket
:
response
:
Union
[
RPCStartupResponse
,
BaseException
]
try
:
identity
,
message
=
socket
.
recv_multipart
(
copy
=
False
)
request
:
RPCStartupRequest
=
pickle
.
loads
(
message
.
buffer
)
# Handle the query from the Client.
if
request
==
RPCStartupRequest
.
IS_SERVER_READY
:
tracing_enabled
=
self
.
engine
.
is_tracing_enabled
()
response
=
RPCStartupResponse
(
tracing_enabled
=
tracing_enabled
)
except
Exception
as
e
:
response
=
e
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
response
)),
copy
=
False
)
def
run_engine_loop
(
self
):
"""Core busy loop of the LLMEngine."""
while
True
:
self
.
_alive
()
if
not
self
.
engine
.
has_unfinished_requests
():
# Poll until there is work to do.
while
self
.
input_socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
self
.
_alive
()
self
.
engine
.
do_log_stats
()
logger
.
debug
(
"Waiting for new requests in engine loop."
)
# Handle any input from the client.
self
.
handle_new_input
()
# Engine step.
request_outputs
=
self
.
engine_step
()
# Send request outputs (if async, done in engine_step callback).
if
not
self
.
use_async_sockets
:
self
.
_send_outputs
(
request_outputs
)
def
engine_step
(
self
)
->
List
[
RequestOutput
]:
"""Engine step wrapper with error handling."""
try
:
return
self
.
engine
.
step
()
except
SystemExit
:
raise
except
BaseException
as
e
:
self
.
_set_errored
(
e
)
rpc_err
=
RPCError
(
request_id
=
None
,
is_engine_errored
=
True
,
exception
=
e
)
self
.
_send_outputs
(
rpc_err
)
raise
e
def
handle_new_input
(
self
):
"""Handle new input from the socket"""
try
:
while
self
.
input_socket
.
poll
(
timeout
=
0
)
!=
0
:
frames
=
self
.
input_socket
.
recv_multipart
(
copy
=
False
)
request
=
pickle
.
loads
(
frames
[
0
].
buffer
)
if
isinstance
(
request
,
RPCProcessRequest
):
if
len
(
frames
)
>
1
:
# Use cloudpickle for logits processors
assert
isinstance
(
request
.
params
,
SamplingParams
)
lprocs
=
cloudpickle
.
loads
(
frames
[
1
].
buffer
)
request
.
params
.
logits_processors
=
lprocs
self
.
_handle_process_request
(
request
)
elif
isinstance
(
request
,
RPCAbortRequest
):
self
.
_handle_abort_request
(
request
)
elif
isinstance
(
request
,
RPCUProfileRequest
):
if
request
==
RPCUProfileRequest
.
START_PROFILE
:
self
.
start_profile
()
else
:
self
.
stop_profile
()
else
:
raise
ValueError
(
"Unknown RPCRequest Type: "
f
"
{
type
(
request
)
}
"
)
except
Exception
as
e
:
self
.
_set_errored
(
e
)
self
.
_send_unhealthy
(
e
)
raise
e
def
_handle_process_request
(
self
,
request
:
RPCProcessRequest
):
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
request_id
=
request
.
request_id
if
self
.
_errored_with
is
not
None
:
rpc_err
=
RPCError
(
request_id
=
request_id
,
is_engine_errored
=
True
,
exception
=
ENGINE_DEAD_ERROR
(
self
.
_errored_with
))
self
.
_send_outputs
(
rpc_err
)
try
:
self
.
engine
.
add_request
(
request_id
=
request_id
,
inputs
=
request
.
inputs
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
prompt_adapter_request
=
request
.
prompt_adapter_request
)
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request
.
request_id
)
except
Exception
as
e
:
# We do not set self._errored = True here, since the error
# is due to an issue adding this request to the engine,
# rather than an issue with the engine itself.
is_errored
=
self
.
_errored_with
is
not
None
rpc_err
=
RPCError
(
request_id
=
request_id
,
is_engine_errored
=
is_errored
,
exception
=
e
)
self
.
_send_outputs
(
rpc_err
)
# Remove request from the engine.
self
.
engine
.
abort_request
(
request_id
)
def
_handle_abort_request
(
self
,
request
:
RPCAbortRequest
):
self
.
engine
.
abort_request
(
request
.
request_id
)
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
def
_heartbeat_loop
(
self
):
while
not
self
.
_heartbeat_stop_event
.
wait
(
timeout
=
self
.
heartbeat_interval_seconds
):
# Loops until the stop event is set
self
.
_heartbeat
()
logger
.
debug
(
"Exiting MQLLMEngine heartbeat thread"
)
def
_heartbeat
(
self
):
# Send unhealthy if engine has already errored
if
self
.
_errored_with
is
not
None
:
self
.
_send_unhealthy
(
self
.
_errored_with
)
# Check for life of the main loop
elif
time
.
time
()
-
self
.
_last_alive_time
>
self
.
last_alive_threshold
:
self
.
_send_unhealthy
(
RuntimeError
(
"Engine loop has died"
))
else
:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try
:
self
.
engine
.
check_health
()
self
.
_send_healthy
()
except
Exception
as
e
:
self
.
_set_errored
(
e
)
self
.
_send_unhealthy
(
e
)
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
"""Send List of RequestOutput to RPCClient."""
if
outputs
:
output_bytes
=
pickle
.
dumps
(
outputs
)
self
.
output_socket
.
send_multipart
((
output_bytes
,
),
copy
=
False
)
def
_send_healthy
(
self
):
"""Send HEALTHY message to RPCClient."""
if
not
self
.
heartbeat_socket
.
closed
:
self
.
heartbeat_socket
.
send_multipart
(
HEALTHY_RESPONSE
,
copy
=
False
)
def
_send_unhealthy
(
self
,
error
:
BaseException
):
"""Send UNHEALTHY message to RPCClient."""
if
not
self
.
heartbeat_socket
.
closed
:
error_bytes
=
pickle
.
dumps
(
error
)
self
.
heartbeat_socket
.
send_multipart
((
error_bytes
,
),
copy
=
False
)
def
_async_socket_engine_callback
(
self
,
request_outputs
:
REQUEST_OUTPUTS_T
):
"""Callback used by engine to make socket handling async with GPU."""
self
.
_send_outputs
(
request_outputs
)
self
.
handle_new_input
()
def
_set_errored
(
self
,
e
:
BaseException
):
"""Log and set errored status if this is the first issue."""
if
self
.
_errored_with
is
None
:
self
.
_errored_with
=
e
def
_alive
(
self
):
self
.
_last_alive_time
=
time
.
time
()
def
start_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
start_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
def
stop_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
stop_profile
()
else
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop_profile"
)
def
run_mp_engine
(
engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
ipc_path
:
str
):
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm
raise
KeyboardInterrupt
(
"MQLLMEngine terminated"
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
engine
=
MQLLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
usage_context
,
ipc_path
=
ipc_path
)
engine
.
start
()
vllm/engine/output_processor/multi_step.py
View file @
539aa992
...
@@ -9,8 +9,8 @@ from vllm.engine.output_processor.single_step import (
...
@@ -9,8 +9,8 @@ from vllm.engine.output_processor.single_step import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Sequence
,
SequenceGroup
,
SequenceOutput
,
SequenceStatus
)
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -110,10 +110,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -110,10 +110,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# we can take the first sample.
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
#
-1 means the output token is not
valid (eg. due to spec decode
#
entries in sample tokens may be in
valid (eg. due to spec decode
# rejecting tokens).
# rejecting tokens).
valid_samples
=
[
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
]
assert
valid_samples
assert
valid_samples
...
...
vllm/engine/protocol.py
View file @
539aa992
...
@@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
...
@@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
@
runtime_checkable
@
runtime_checkable
class
Async
EngineClient
(
Protocol
):
class
EngineClient
(
Protocol
):
"""Protocol class for Clients to
AsyncLLM
Engine"""
"""Protocol class for Clients to Engine"""
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
...
@@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol):
...
@@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol):
...
...
@
property
@
property
def
limit_concurrency
(
self
)
->
Optional
[
int
]
:
def
dead_error
(
self
)
->
BaseException
:
"""Maximum number of concurrently running requests."""
...
def
generate
(
def
generate
(
self
,
self
,
...
...
vllm/entrypoints/api_server.py
View file @
539aa992
...
@@ -121,7 +121,6 @@ async def run_server(args: Namespace,
...
@@ -121,7 +121,6 @@ async def run_server(args: Namespace,
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
engine
=
engine
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
log_level
=
args
.
log_level
,
...
...
vllm/entrypoints/chat_utils.py
View file @
539aa992
...
@@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config
.
image_token_index
)
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
return
"<image>"
return
"<image>"
if
model_type
==
"mllama"
:
return
"<|image|>"
if
model_type
==
"qwen2_vl"
:
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|image_pad|><|vision_end|>"
return
"<|vision_start|><|image_pad|><|vision_end|>"
...
@@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
...
@@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser
=
partial
(
cast
,
ChatCompletionContentPartImageParam
)
_ImageParser
=
partial
(
cast
,
ChatCompletionContentPartImageParam
)
_AudioParser
=
partial
(
cast
,
ChatCompletionContentPartAudioParam
)
_AudioParser
=
partial
(
cast
,
ChatCompletionContentPartAudioParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
MODEL_KEEP_MULTI_MODAL_CONTENT
=
{
'mllama'
}
def
_parse_chat_message_content_parts
(
def
_parse_chat_message_content_parts
(
...
@@ -368,7 +371,11 @@ def _parse_chat_message_content_parts(
...
@@ -368,7 +371,11 @@ def _parse_chat_message_content_parts(
texts
:
List
[
str
]
=
[]
texts
:
List
[
str
]
=
[]
mm_parser
=
mm_tracker
.
create_parser
()
mm_parser
=
mm_tracker
.
create_parser
()
keep_multimodal_content
=
\
mm_tracker
.
_model_config
.
hf_config
.
model_type
in
\
MODEL_KEEP_MULTI_MODAL_CONTENT
has_image
=
False
for
part
in
parts
:
for
part
in
parts
:
part_type
=
part
[
"type"
]
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
if
part_type
==
"text"
:
...
@@ -383,6 +390,7 @@ def _parse_chat_message_content_parts(
...
@@ -383,6 +390,7 @@ def _parse_chat_message_content_parts(
"will be ignored."
)
"will be ignored."
)
mm_parser
.
parse_image
(
image_url
[
"url"
])
mm_parser
.
parse_image
(
image_url
[
"url"
])
has_image
=
True
elif
part_type
==
"audio_url"
:
elif
part_type
==
"audio_url"
:
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
...
@@ -394,12 +402,20 @@ def _parse_chat_message_content_parts(
...
@@ -394,12 +402,20 @@ def _parse_chat_message_content_parts(
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
text_prompt
=
"
\n
"
.
join
(
texts
)
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
keep_multimodal_content
:
if
mm_placeholder_counts
:
text_prompt
=
"
\n
"
.
join
(
texts
)
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
role_content
=
[{
'type'
:
'text'
,
'text'
:
text_prompt
}]
text_prompt
)
if
has_image
:
return
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
role_content
=
[{
'type'
:
'image'
}]
+
role_content
return
[
ConversationMessage
(
role
=
role
,
content
=
role_content
)]
# type: ignore
else
:
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
mm_placeholder_counts
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
text_prompt
)
return
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
# No need to validate using Pydantic again
# No need to validate using Pydantic again
...
...
vllm/entrypoints/launcher.py
View file @
539aa992
...
@@ -4,19 +4,18 @@ from http import HTTPStatus
...
@@ -4,19 +4,18 @@ from http import HTTPStatus
from
typing
import
Any
from
typing
import
Any
import
uvicorn
import
uvicorn
from
fastapi
import
FastAPI
,
Response
from
fastapi
import
FastAPI
,
Request
,
Response
from
vllm
import
envs
from
vllm
import
envs
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.
protocol
import
Async
Engine
Client
from
vllm.engine.
multiprocessing
import
MQ
Engine
DeadError
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_process_using_port
from
vllm.utils
import
find_process_using_port
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
engine
:
AsyncEngineClient
,
async
def
serve_http
(
app
:
FastAPI
,
**
uvicorn_kwargs
:
Any
):
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
for
route
in
app
.
routes
:
methods
=
getattr
(
route
,
"methods"
,
None
)
methods
=
getattr
(
route
,
"methods"
,
None
)
...
@@ -27,18 +26,9 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
...
@@ -27,18 +26,9 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if
engine
.
limit_concurrency
is
not
None
:
logger
.
info
(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing"
,
engine
.
limit_concurrency
)
uvicorn_kwargs
[
"limit_concurrency"
]
=
engine
.
limit_concurrency
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
server
=
uvicorn
.
Server
(
config
)
server
=
uvicorn
.
Server
(
config
)
_add_shutdown_handlers
(
app
,
server
,
engine
)
_add_shutdown_handlers
(
app
,
server
)
loop
=
asyncio
.
get_running_loop
()
loop
=
asyncio
.
get_running_loop
()
...
@@ -64,19 +54,19 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
...
@@ -64,19 +54,19 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger
.
debug
(
logger
.
debug
(
"port %s is used by process %s launched with command:
\n
%s"
,
"port %s is used by process %s launched with command:
\n
%s"
,
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
logger
.
info
(
"
Gracefully stopping http
server"
)
logger
.
info
(
"
Shutting down FastAPI HTTP
server
.
"
)
return
server
.
shutdown
()
return
server
.
shutdown
()
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
,
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
)
->
None
:
engine
:
AsyncEngineClient
)
->
None
:
"""Adds handlers for fatal errors that should crash the server"""
"""Adds handlers for fatal errors that should crash the server"""
@
app
.
exception_handler
(
RuntimeError
)
@
app
.
exception_handler
(
RuntimeError
)
async
def
runtime_error_handler
(
_
,
__
):
async
def
runtime_error_handler
(
request
:
Request
,
__
):
"""On generic runtime error, check to see if the engine has died.
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine
=
request
.
app
.
state
.
engine_client
if
(
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine
.
errored
if
(
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine
.
errored
and
not
engine
.
is_running
):
and
not
engine
.
is_running
):
logger
.
fatal
(
"AsyncLLMEngine has failed, terminating server "
logger
.
fatal
(
"AsyncLLMEngine has failed, terminating server "
...
@@ -91,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
...
@@ -91,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
@
app
.
exception_handler
(
AsyncEngineDeadError
)
@
app
.
exception_handler
(
AsyncEngineDeadError
)
async
def
engine_dead_handler
(
_
,
__
):
async
def
async_
engine_dead_handler
(
_
,
__
):
"""Kill the server if the async engine is already dead. It will
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
not handle any further requests."""
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
...
@@ -100,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
...
@@ -100,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
server
.
should_exit
=
True
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
@
app
.
exception_handler
(
MQEngineDeadError
)
async
def
mq_engine_dead_handler
(
_
,
__
):
"""Kill the server if the mq engine is already dead. It will
not handle any further requests."""
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
logger
.
fatal
(
"MQLLMEngine is already dead, terminating server "
"process"
)
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
vllm/entrypoints/llm.py
View file @
539aa992
import
itertools
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
ClassVar
,
List
,
Optional
,
Sequence
,
Union
,
cast
,
overload
from
dataclasses
import
dataclass
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
,
overload
)
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -29,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
...
@@ -29,6 +32,37 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
dataclass
class
BeamSearchSequence
:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens
:
List
[
int
]
cum_logprob
:
float
=
0.0
text
:
Optional
[
str
]
=
None
@
dataclass
class
BeamSearchOutput
:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences
:
List
[
BeamSearchSequence
]
class
BeamSearchInstance
:
def
__init__
(
self
,
prompt_tokens
:
List
[
int
]):
self
.
beams
:
List
[
BeamSearchSequence
]
=
[
BeamSearchSequence
(
tokens
=
prompt_tokens
)
]
self
.
completed
:
List
[
BeamSearchSequence
]
=
[]
class
LLM
:
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
"""An LLM for generating texts from given prompts and sampling parameters.
...
@@ -88,7 +122,9 @@ class LLM:
...
@@ -88,7 +122,9 @@ class LLM:
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See ParallelConfig
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
:ref:`engine_args`)
...
@@ -131,15 +167,14 @@ class LLM:
...
@@ -131,15 +167,14 @@ class LLM:
max_seq_len_to_capture
:
int
=
8192
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
'''
'''
LLM constructor.
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True
it defaults to False.
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
'''
'''
if
"disable_log_stats"
not
in
kwargs
:
if
"disable_log_stats"
not
in
kwargs
:
...
@@ -173,6 +208,7 @@ class LLM:
...
@@ -173,6 +208,7 @@ class LLM:
max_seq_len_to_capture
=
max_seq_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
disable_async_output_proc
=
disable_async_output_proc
,
mm_processor_kwargs
=
mm_processor_kwargs
,
**
kwargs
,
**
kwargs
,
)
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
...
@@ -284,7 +320,8 @@ class LLM:
...
@@ -284,7 +320,8 @@ class LLM:
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
GuidedDecodingRequest
]]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -303,6 +340,8 @@ class LLM:
...
@@ -303,6 +340,8 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
Returns:
Returns:
A list of ``RequestOutput`` objects containing the
A list of ``RequestOutput`` objects containing the
...
@@ -343,20 +382,122 @@ class LLM:
...
@@ -343,20 +382,122 @@ class LLM:
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
guided_options
=
guided_options_request
)
guided_options
=
guided_options_request
,
priority
=
priority
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
def
beam_search
(
self
,
prompts
:
List
[
Union
[
str
,
List
[
int
]]],
beam_width
:
int
,
max_tokens
:
int
,
ignore_eos
:
bool
=
False
,
)
->
List
[
BeamSearchOutput
]:
"""
Generate sequences using beam search.
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
beam_width: The number of beams to keep at each step.
max_tokens: The max number of tokens to generate for each prompt.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
0.0
)
instances
:
List
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
prompt_tokens
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
for
_
in
range
(
max_tokens
):
all_beams
:
List
[
BeamSearchSequence
]
=
list
(
sum
((
instance
.
beams
for
instance
in
instances
),
[]))
pos
=
[
0
]
+
list
(
itertools
.
accumulate
(
len
(
instance
.
beams
)
for
instance
in
instances
))
instance_start_and_end
:
List
[
Tuple
[
int
,
int
]]
=
list
(
zip
(
pos
[:
-
1
],
pos
[
1
:]))
if
len
(
all_beams
)
==
0
:
break
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
]
# only runs for one step
# we don't need to use tqdm here
output
=
self
.
generate
(
prompts_batch
,
sampling_params
=
beam_search_params
,
use_tqdm
=
False
)
for
(
start
,
end
),
instance
in
zip
(
instance_start_and_end
,
instances
):
instance_new_beams
=
[]
for
i
in
range
(
start
,
end
):
current_beam
=
all_beams
[
i
]
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
# if `result.outputs[0].logprobs` is None, it means
# the sequence is completed because of the max-model-len
# or abortion. we don't need to add it to the new beams.
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
instance
.
completed
.
append
(
new_beam
)
else
:
instance_new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
instance_new_beams
,
key
=
lambda
x
:
x
.
cum_logprob
,
reverse
=
True
)
instance
.
beams
=
sorted_beams
[:
beam_width
]
outputs
=
[]
for
instance
in
instances
:
instance
.
completed
.
extend
(
instance
.
beams
)
sorted_completed
=
sorted
(
instance
.
completed
,
key
=
lambda
x
:
x
.
cum_logprob
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
)
outputs
.
append
(
BeamSearchOutput
(
sequences
=
best_beams
))
return
outputs
def
chat
(
def
chat
(
self
,
self
,
messages
:
List
[
ChatCompletionMessageParam
],
messages
:
Union
[
List
[
ChatCompletionMessageParam
],
List
[
List
[
ChatCompletionMessageParam
]]],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
List
[
SamplingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
tools
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""
"""
Generate responses for a chat conversation.
Generate responses for a chat conversation.
...
@@ -369,8 +510,9 @@ class LLM:
...
@@ -369,8 +510,9 @@ class LLM:
to the OpenAI API.
to the OpenAI API.
Args:
Args:
messages: A single conversation represented as a list of messages.
messages: A list of conversations or a single conversation.
Each message is a dictionary with 'role' and 'content' keys.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a single value, it is applied to every prompt. When it
...
@@ -387,40 +529,56 @@ class LLM:
...
@@ -387,40 +529,56 @@ class LLM:
A list of ``RequestOutput`` objects containing the generated
A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages.
responses in the same order as the input messages.
"""
"""
list_of_messages
:
List
[
List
[
ChatCompletionMessageParam
]]
tokenizer
=
self
.
get_tokenizer
()
# Handle multi and single conversations
model_config
=
self
.
llm_engine
.
get_model_config
()
if
is_list_of
(
messages
,
list
):
# messages is List[List[...]]
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
list_of_messages
=
messages
tokenizer
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
messages
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
else
:
else
:
prompt
=
apply_hf_chat_template
(
# messages is List[...]
tokenizer
,
list_of_messages
=
[
messages
]
conversation
=
conversation
,
chat_template
=
chat_template
,
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
add_generation_prompt
=
add_generation_prompt
,
)
for
msgs
in
list_of_messages
:
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
conversation
,
mm_data
=
parse_chat_messages
(
msgs
,
model_config
,
tokenizer
)
prompt_data
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt_data
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
msgs
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
tools
=
tools
,
)
else
:
prompt_data
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
tools
=
tools
,
)
prompt
:
Union
[
TokensPrompt
,
TextPrompt
]
if
is_list_of
(
prompt_data
,
int
):
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_data
)
else
:
prompt
=
TextPrompt
(
prompt
=
prompt_data
)
inputs
:
PromptInputs
if
mm_data
is
not
None
:
if
is_list_of
(
prompt
,
int
):
prompt
[
"multi_modal_data"
]
=
mm_data
inputs
=
TokensPrompt
(
prompt_token_ids
=
prompt
)
else
:
inputs
=
TextPrompt
(
prompt
=
prompt
)
if
mm_data
is
not
None
:
prompts
.
append
(
prompt
)
inputs
[
"multi_modal_data"
]
=
mm_data
return
self
.
generate
(
return
self
.
generate
(
inpu
ts
,
promp
ts
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
@@ -628,6 +786,7 @@ class LLM:
...
@@ -628,6 +786,7 @@ class LLM:
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
...
@@ -657,6 +816,7 @@ class LLM:
...
@@ -657,6 +816,7 @@ class LLM:
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
lora_request
,
Sequence
)
else
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
[
i
]
if
priority
else
0
,
)
)
def
_add_request
(
def
_add_request
(
...
@@ -665,6 +825,7 @@ class LLM:
...
@@ -665,6 +825,7 @@ class LLM:
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
self
.
llm_engine
.
add_request
(
...
@@ -673,6 +834,7 @@ class LLM:
...
@@ -673,6 +834,7 @@ class LLM:
params
,
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
)
def
_add_guided_processor
(
def
_add_guided_processor
(
...
...
vllm/entrypoints/openai/api_server.py
View file @
539aa992
...
@@ -4,16 +4,21 @@ import inspect
...
@@ -4,16 +4,21 @@ import inspect
import
multiprocessing
import
multiprocessing
import
os
import
os
import
re
import
re
import
signal
import
socket
import
tempfile
import
tempfile
from
argparse
import
Namespace
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
functools
import
partial
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
AsyncIterator
,
Optional
,
Set
from
typing
import
AsyncIterator
,
Set
import
uvloop
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
starlette.datastructures
import
State
from
starlette.routing
import
Mount
from
starlette.routing
import
Mount
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
...
@@ -21,7 +26,9 @@ import vllm.envs as envs
...
@@ -21,7 +26,9 @@ import vllm.envs as envs
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.multiprocessing.engine
import
run_mp_engine
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
...
@@ -39,12 +46,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -39,12 +46,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeRequest
,
TokenizeRequest
,
TokenizeResponse
,
TokenizeResponse
,
UnloadLoraAdapterRequest
)
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
from
vllm.entrypoints.openai.serving_tokenization
import
(
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
OpenAIServingTokenization
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -54,12 +60,6 @@ from vllm.version import __version__ as VLLM_VERSION
...
@@ -54,12 +60,6 @@ from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
TIMEOUT_KEEP_ALIVE
=
5
# seconds
async_engine_client
:
AsyncEngineClient
engine_args
:
AsyncEngineArgs
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
openai_serving_tokenization
:
OpenAIServingTokenization
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
...
@@ -68,49 +68,42 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
...
@@ -68,49 +68,42 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
def
model_is_embedding
(
model_name
:
str
,
trust_remote_code
:
bool
,
quantization
:
Optional
[
str
])
->
bool
:
return
ModelConfig
(
model
=
model_name
,
tokenizer
=
model_name
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
trust_remote_code
,
quantization
=
quantization
,
seed
=
0
,
dtype
=
"auto"
).
embedding_mode
@
asynccontextmanager
@
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
async
def
lifespan
(
app
:
FastAPI
):
try
:
async
def
_force_log
():
if
app
.
state
.
log_stats
:
while
True
:
engine_client
:
EngineClient
=
app
.
state
.
engine_client
await
asyncio
.
sleep
(
10
)
await
async_engine_client
.
do_log_stats
()
async
def
_force_log
():
while
True
:
if
not
engine_args
.
disable_log_stats
:
await
asyncio
.
sleep
(
10.
)
task
=
asyncio
.
create_task
(
_force_log
())
await
engine_client
.
do_log_stats
()
_running_tasks
.
add
(
task
)
task
.
add_done_callback
(
_running_tasks
.
remove
)
task
=
asyncio
.
create_task
(
_force_log
())
_running_tasks
.
add
(
task
)
yield
task
.
add_done_callback
(
_running_tasks
.
remove
)
else
:
task
=
None
try
:
yield
finally
:
if
task
is
not
None
:
task
.
cancel
()
finally
:
# Ensure app state including engine ref is gc'd
del
app
.
state
@
asynccontextmanager
@
asynccontextmanager
async
def
build_async_engine_client
(
async
def
build_async_engine_client
(
args
:
Namespace
)
->
AsyncIterator
[
Optional
[
Async
EngineClient
]
]
:
args
:
Namespace
)
->
AsyncIterator
[
EngineClient
]:
# Context manager to handle
async_
engine_client lifecycle
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
# Ensures everything is shutdown and cleaned up on error/exit
global
engine_args
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
# Backend itself still global for the silly lil' health handler
global
async_engine_client
async
with
build_async_engine_client_from_engine_args
(
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
async_engine_client
=
engine
# type: ignore[assignment]
yield
engine
yield
engine
...
@@ -118,26 +111,35 @@ async def build_async_engine_client(
...
@@ -118,26 +111,35 @@ async def build_async_engine_client(
async
def
build_async_engine_client_from_engine_args
(
async
def
build_async_engine_client_from_engine_args
(
engine_args
:
AsyncEngineArgs
,
engine_args
:
AsyncEngineArgs
,
disable_frontend_multiprocessing
:
bool
=
False
,
disable_frontend_multiprocessing
:
bool
=
False
,
)
->
AsyncIterator
[
Optional
[
Async
EngineClient
]
]
:
)
->
AsyncIterator
[
EngineClient
]:
"""
"""
Create
Async
EngineClient, either:
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
Returns the Client or None if the creation failed.
"""
"""
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# Fall back
# TODO: support embedding model via RPC.
# TODO: fill out feature matrix.
if
(
model_is_embedding
(
engine_args
.
model
,
engine_args
.
trust_remote_code
,
if
(
MQLLMEngineClient
.
is_unsupported_config
(
engine_args
)
engine_args
.
quantization
)
or
disable_frontend_multiprocessing
):
or
disable_frontend_multiprocessing
):
engine_client
=
AsyncLLMEngine
.
from_engine_args
(
engine_config
=
engine_args
.
create_engine_config
()
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
uses_ray
=
getattr
(
AsyncLLMEngine
.
_get_executor_cls
(
engine_config
),
try
:
"uses_ray"
,
False
)
yield
engine_client
finally
:
build_engine
=
partial
(
AsyncLLMEngine
.
from_engine_args
,
engine_client
.
shutdown_background_loop
()
engine_args
=
engine_args
,
engine_config
=
engine_config
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
if
uses_ray
:
# Must run in main thread with ray for its signal handlers to work
engine_client
=
build_engine
()
else
:
engine_client
=
await
asyncio
.
get_running_loop
().
run_in_executor
(
None
,
build_engine
)
yield
engine_client
return
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
# Otherwise, use the multiprocessing AsyncLLMEngine.
...
@@ -158,56 +160,58 @@ async def build_async_engine_client_from_engine_args(
...
@@ -158,56 +160,58 @@ async def build_async_engine_client_from_engine_args(
"and vLLM will properly handle cleanup."
)
"and vLLM will properly handle cleanup."
)
# Select random path for IPC.
# Select random path for IPC.
rpc_path
=
get_open_zmq_ipc_path
()
ipc_path
=
get_open_zmq_ipc_path
()
logger
.
info
(
"Multiprocessing frontend to use %s for RPC Path."
,
logger
.
info
(
"Multiprocessing frontend to use %s for IPC Path."
,
rpc_path
)
ipc_path
)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client
=
AsyncEngineRPCClient
(
rpc_path
)
# Start RPCServer in separate process (holds the AsyncLLMEngine).
# Start RPCServer in separate process (holds the LLMEngine).
context
=
multiprocessing
.
get_context
(
"spawn"
)
# the current process might have CUDA context,
# the current process might have CUDA context,
# so we need to spawn a new process
# so we need to spawn a new process
rpc_server_process
=
context
.
Process
(
context
=
multiprocessing
.
get_context
(
"spawn"
)
target
=
run_rpc_server
,
args
=
(
engine_args
,
UsageContext
.
OPENAI_API_SERVER
,
rpc_path
))
engine_process
=
context
.
Process
(
target
=
run_mp_engine
,
rpc_server_process
.
start
()
args
=
(
engine_args
,
logger
.
info
(
"Started engine process with PID %d"
,
UsageContext
.
OPENAI_API_SERVER
,
rpc_server_process
.
pid
)
ipc_path
))
engine_process
.
start
()
logger
.
info
(
"Started engine process with PID %d"
,
engine_process
.
pid
)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config
=
engine_args
.
create_engine_config
()
mp_engine_client
=
MQLLMEngineClient
(
ipc_path
,
engine_config
)
try
:
try
:
while
True
:
while
True
:
try
:
try
:
await
rpc
_client
.
setup
()
await
mp_engine
_client
.
setup
()
break
break
except
TimeoutError
:
except
TimeoutError
:
if
not
rpc_server_process
.
is_alive
():
if
not
engine_process
.
is_alive
():
logger
.
error
(
raise
RuntimeError
(
"RPCServer process died before responding "
"Engine process failed to start"
)
from
None
"to readiness probe"
)
yield
None
yield
mp_engine_client
# type: ignore[misc]
return
yield
rpc_client
# type: ignore[misc]
finally
:
finally
:
# Ensure rpc server process was terminated
# Ensure rpc server process was terminated
rpc_server
_process
.
terminate
()
engine
_process
.
terminate
()
# Close all open connections to the backend
# Close all open connections to the backend
rpc
_client
.
close
()
mp_engine
_client
.
close
()
# Wait for server process to join
# Wait for engine process to join
rpc_server_process
.
join
()
engine_process
.
join
(
4
)
if
engine_process
.
exitcode
is
None
:
# Kill if taking longer than 5 seconds to stop
engine_process
.
kill
()
# Lazy import for prometheus multiprocessing.
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
# See https://prometheus.github.io/client_python/multiprocess/
from
prometheus_client
import
multiprocess
from
prometheus_client
import
multiprocess
multiprocess
.
mark_process_dead
(
rpc_server
_process
.
pid
)
multiprocess
.
mark_process_dead
(
engine
_process
.
pid
)
router
=
APIRouter
()
router
=
APIRouter
()
...
@@ -239,16 +243,36 @@ def mount_metrics(app: FastAPI):
...
@@ -239,16 +243,36 @@ def mount_metrics(app: FastAPI):
app
.
routes
.
append
(
metrics_route
)
app
.
routes
.
append
(
metrics_route
)
def
chat
(
request
:
Request
)
->
OpenAIServingChat
:
return
request
.
app
.
state
.
openai_serving_chat
def
completion
(
request
:
Request
)
->
OpenAIServingCompletion
:
return
request
.
app
.
state
.
openai_serving_completion
def
tokenization
(
request
:
Request
)
->
OpenAIServingTokenization
:
return
request
.
app
.
state
.
openai_serving_tokenization
def
embedding
(
request
:
Request
)
->
OpenAIServingEmbedding
:
return
request
.
app
.
state
.
openai_serving_embedding
def
engine_client
(
request
:
Request
)
->
EngineClient
:
return
request
.
app
.
state
.
engine_client
@
router
.
get
(
"/health"
)
@
router
.
get
(
"/health"
)
async
def
health
()
->
Response
:
async
def
health
(
raw_request
:
Request
)
->
Response
:
"""Health check."""
"""Health check."""
await
async_
engine_client
.
check_health
()
await
engine_client
(
raw_request
)
.
check_health
()
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/tokenize"
)
@
router
.
post
(
"/tokenize"
)
async
def
tokenize
(
request
:
TokenizeRequest
):
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_tokenization
.
create_tokenize
(
request
)
generator
=
await
tokenization
(
raw_request
)
.
create_tokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
status_code
=
generator
.
code
)
...
@@ -259,8 +283,8 @@ async def tokenize(request: TokenizeRequest):
...
@@ -259,8 +283,8 @@ async def tokenize(request: TokenizeRequest):
@
router
.
post
(
"/detokenize"
)
@
router
.
post
(
"/detokenize"
)
async
def
detokenize
(
request
:
DetokenizeRequest
):
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_tokenization
.
create_detokenize
(
request
)
generator
=
await
tokenization
(
raw_request
)
.
create_detokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
status_code
=
generator
.
code
)
...
@@ -271,8 +295,8 @@ async def detokenize(request: DetokenizeRequest):
...
@@ -271,8 +295,8 @@ async def detokenize(request: DetokenizeRequest):
@
router
.
get
(
"/v1/models"
)
@
router
.
get
(
"/v1/models"
)
async
def
show_available_models
():
async
def
show_available_models
(
raw_request
:
Request
):
models
=
await
openai_serving_completion
.
show_available_models
()
models
=
await
completion
(
raw_request
)
.
show_available_models
()
return
JSONResponse
(
content
=
models
.
model_dump
())
return
JSONResponse
(
content
=
models
.
model_dump
())
...
@@ -286,7 +310,7 @@ async def show_version():
...
@@ -286,7 +310,7 @@ async def show_version():
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
generator
=
await
openai_serving_chat
.
create_chat_completion
(
generator
=
await
chat
(
raw_request
)
.
create_chat_completion
(
request
,
raw_request
)
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
...
@@ -301,7 +325,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
...
@@ -301,7 +325,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@
router
.
post
(
"/v1/completions"
)
@
router
.
post
(
"/v1/completions"
)
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_completion
.
create_completion
(
generator
=
await
completion
(
raw_request
)
.
create_completion
(
request
,
raw_request
)
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
...
@@ -314,7 +338,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -314,7 +338,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@
router
.
post
(
"/v1/embeddings"
)
@
router
.
post
(
"/v1/embeddings"
)
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_embedding
.
create_embedding
(
generator
=
await
embedding
(
raw_request
)
.
create_embedding
(
request
,
raw_request
)
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
...
@@ -331,16 +355,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
...
@@ -331,16 +355,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
"used for local development!"
)
"used for local development!"
)
@
router
.
post
(
"/start_profile"
)
@
router
.
post
(
"/start_profile"
)
async
def
start_profile
():
async
def
start_profile
(
raw_request
:
Request
):
logger
.
info
(
"Starting profiler..."
)
logger
.
info
(
"Starting profiler..."
)
await
async_
engine_client
.
start_profile
()
await
engine_client
(
raw_request
)
.
start_profile
()
logger
.
info
(
"Profiler started."
)
logger
.
info
(
"Profiler started."
)
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/stop_profile"
)
@
router
.
post
(
"/stop_profile"
)
async
def
stop_profile
():
async
def
stop_profile
(
raw_request
:
Request
):
logger
.
info
(
"Stopping profiler..."
)
logger
.
info
(
"Stopping profiler..."
)
await
async_
engine_client
.
stop_profile
()
await
engine_client
(
raw_request
)
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
logger
.
info
(
"Profiler stopped."
)
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
...
@@ -351,13 +375,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
...
@@ -351,13 +375,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"This should ONLY be used for local development!"
)
"This should ONLY be used for local development!"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
):
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
,
response
=
await
openai_serving_chat
.
load_lora_adapter
(
request
)
raw_request
:
Request
):
response
=
await
chat
(
raw_request
).
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
load_lora_adapter
(
request
)
response
=
await
completion
(
raw_request
)
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
...
@@ -365,13 +390,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
...
@@ -365,13 +390,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return
Response
(
status_code
=
200
,
content
=
response
)
return
Response
(
status_code
=
200
,
content
=
response
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
):
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
,
response
=
await
openai_serving_chat
.
unload_lora_adapter
(
request
)
raw_request
:
Request
):
response
=
await
chat
(
raw_request
).
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
unload_lora_adapter
(
request
)
response
=
await
completion
(
raw_request
)
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
...
@@ -380,7 +406,13 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
...
@@ -380,7 +406,13 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
FastAPI
(
lifespan
=
lifespan
)
if
args
.
disable_fastapi_docs
:
app
=
FastAPI
(
openapi_url
=
None
,
docs_url
=
None
,
redoc_url
=
None
,
lifespan
=
lifespan
)
else
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
app
.
include_router
(
router
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
...
@@ -396,7 +428,8 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -396,7 +428,8 @@ def build_app(args: Namespace) -> FastAPI:
@
app
.
exception_handler
(
RequestValidationError
)
@
app
.
exception_handler
(
RequestValidationError
)
async
def
validation_exception_handler
(
_
,
exc
):
async
def
validation_exception_handler
(
_
,
exc
):
err
=
openai_serving_chat
.
create_error_response
(
message
=
str
(
exc
))
chat
=
app
.
state
.
openai_serving_chat
err
=
chat
.
create_error_response
(
message
=
str
(
exc
))
return
JSONResponse
(
err
.
model_dump
(),
return
JSONResponse
(
err
.
model_dump
(),
status_code
=
HTTPStatus
.
BAD_REQUEST
)
status_code
=
HTTPStatus
.
BAD_REQUEST
)
...
@@ -428,33 +461,34 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -428,33 +461,34 @@ def build_app(args: Namespace) -> FastAPI:
return
app
return
app
async
def
init_app
(
def
init_app_state
(
async_engine_client
:
AsyncEngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
state
:
State
,
args
:
Namespace
,
args
:
Namespace
,
)
->
FastAPI
:
)
->
None
:
app
=
build_app
(
args
)
if
args
.
served_model_name
is
not
None
:
if
args
.
served_model_name
is
not
None
:
served_model_names
=
args
.
served_model_name
served_model_names
=
args
.
served_model_name
else
:
else
:
served_model_names
=
[
args
.
model
]
served_model_names
=
[
args
.
model
]
model_config
=
await
async_engine_client
.
get_model_config
()
if
args
.
disable_log_requests
:
if
args
.
disable_log_requests
:
request_logger
=
None
request_logger
=
None
else
:
else
:
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
global
openai_serving_chat
base_model_paths
=
[
global
openai_serving_completion
BaseModelPath
(
name
=
name
,
model_path
=
args
.
model
)
global
openai_serving_embedding
for
name
in
served_model_names
global
openai_serving_tokenization
]
openai_serving_chat
=
OpenAIServingChat
(
state
.
engine_client
=
engine_client
async_engine_client
,
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
openai_serving_chat
=
OpenAIServingChat
(
engine_client
,
model_config
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
args
.
response_role
,
args
.
response_role
,
lora_modules
=
args
.
lora_modules
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
prompt_adapters
=
args
.
prompt_adapters
,
...
@@ -463,48 +497,54 @@ async def init_app(
...
@@ -463,48 +497,54 @@ async def init_app(
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
tool_parser
=
args
.
tool_call_parser
)
tool_parser
=
args
.
tool_call_parser
)
openai_serving_completion
=
OpenAIServingCompletion
(
state
.
openai_serving_completion
=
OpenAIServingCompletion
(
async_
engine_client
,
engine_client
,
model_config
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
lora_modules
=
args
.
lora_modules
,
lora_modules
=
args
.
lora_modules
,
prompt_adapters
=
args
.
prompt_adapters
,
prompt_adapters
=
args
.
prompt_adapters
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
state
.
openai_serving_embedding
=
OpenAIServingEmbedding
(
async_
engine_client
,
engine_client
,
model_config
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
)
)
openai_serving_tokenization
=
OpenAIServingTokenization
(
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
async_
engine_client
,
engine_client
,
model_config
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
lora_modules
=
args
.
lora_modules
,
lora_modules
=
args
.
lora_modules
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
chat_template
=
args
.
chat_template
,
)
)
app
.
root_path
=
args
.
root_path
return
app
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
logger
.
info
(
"args: %s"
,
args
)
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
temp_socket
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
# If None, creation of the client failed and we exit.
temp_socket
.
bind
((
""
,
args
.
port
))
if
async_engine_client
is
None
:
return
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm while initializing
raise
KeyboardInterrupt
(
"terminated"
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
async
with
build_async_engine_client
(
args
)
as
engine_client
:
app
=
build_app
(
args
)
model_config
=
await
engine_client
.
get_model_config
()
init_app_state
(
engine_client
,
model_config
,
app
.
state
,
args
)
app
=
await
init_app
(
async_engine_client
,
args
)
temp_socket
.
close
(
)
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
engine
=
async_engine_client
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
log_level
=
args
.
uvicorn_log_level
,
...
@@ -528,4 +568,4 @@ if __name__ == "__main__":
...
@@ -528,4 +568,4 @@ if __name__ == "__main__":
parser
=
make_arg_parser
(
parser
)
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
asyncio
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
vllm/entrypoints/openai/cli_args.py
View file @
539aa992
...
@@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action):
...
@@ -31,8 +31,23 @@ class LoRAParserAction(argparse.Action):
lora_list
:
List
[
LoRAModulePath
]
=
[]
lora_list
:
List
[
LoRAModulePath
]
=
[]
for
item
in
values
:
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
if
item
in
[
None
,
''
]:
# Skip if item is None or empty string
lora_list
.
append
(
LoRAModulePath
(
name
,
path
))
continue
if
'='
in
item
and
','
not
in
item
:
# Old format: name=path
name
,
path
=
item
.
split
(
'='
)
lora_list
.
append
(
LoRAModulePath
(
name
,
path
))
else
:
# Assume JSON format
try
:
lora_dict
=
json
.
loads
(
item
)
lora
=
LoRAModulePath
(
**
lora_dict
)
lora_list
.
append
(
lora
)
except
json
.
JSONDecodeError
:
parser
.
error
(
f
"Invalid JSON format for --lora-modules:
{
item
}
"
)
except
TypeError
as
e
:
parser
.
error
(
f
"Invalid fields for --lora-modules:
{
item
}
-
{
str
(
e
)
}
"
)
setattr
(
namespace
,
self
.
dest
,
lora_list
)
setattr
(
namespace
,
self
.
dest
,
lora_list
)
...
@@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default
=
None
,
default
=
None
,
nargs
=
'+'
,
nargs
=
'+'
,
action
=
LoRAParserAction
,
action
=
LoRAParserAction
,
help
=
"LoRA module configurations in the format name=path. "
help
=
"LoRA module configurations in either 'name=path' format"
"Multiple modules can be specified."
)
"or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{
\"
name
\"
:
\"
name
\"
,
\"
local_path
\"
:
\"
path
\"
, "
"
\"
base_model_name
\"
:
\"
id
\"
}'"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--prompt-adapters"
,
"--prompt-adapters"
,
type
=
nullable_str
,
type
=
nullable_str
,
...
@@ -190,6 +209,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -190,6 +209,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'ID numbers being printed in log.'
'ID numbers being printed in log.'
'
\n\n
Default: Unlimited'
)
'
\n\n
Default: Unlimited'
)
parser
.
add_argument
(
"--disable-fastapi-docs"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
return
parser
return
parser
...
...
vllm/entrypoints/openai/protocol.py
View file @
539aa992
...
@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
...
@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens
:
Optional
[
int
]
=
0
completion_tokens
:
Optional
[
int
]
=
0
class
RequestResponseMetadata
(
BaseModel
):
request_id
:
str
final_usage_info
:
Optional
[
UsageInfo
]
=
None
class
JsonSchemaResponseFormat
(
OpenAIBaseModel
):
class
JsonSchemaResponseFormat
(
OpenAIBaseModel
):
name
:
str
name
:
str
description
:
Optional
[
str
]
=
None
description
:
Optional
[
str
]
=
None
...
...
vllm/entrypoints/openai/rpc/client.py
deleted
100644 → 0
View file @
93872128
import
asyncio
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
Any
,
AsyncGenerator
,
Iterator
,
Mapping
,
Optional
from
uuid
import
uuid4
import
cloudpickle
import
zmq
import
zmq.asyncio
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
# yapf: disable
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
VLLM_RPC_SOCKET_LIMIT_CUTOFF
,
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_GET_DATA_TIMEOUT_MS
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
logger
=
init_logger
(
__name__
)
# Path used for inprocess proxy.
INPROC_PROXY_PATH
=
f
"inproc://
{
uuid4
()
}
"
class
RPCClientClosedError
(
Exception
):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class
AsyncEngineRPCClient
:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, )
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RCPGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def
__init__
(
self
,
rpc_path
:
str
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
_data_timeout
=
VLLM_RPC_GET_DATA_TIMEOUT_MS
self
.
_errored
=
False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit
=
self
.
context
.
get
(
zmq
.
constants
.
SOCKET_LIMIT
)
assert
isinstance
(
socket_limit
,
int
)
if
socket_limit
<
VLLM_RPC_SOCKET_LIMIT_CUTOFF
:
raise
ValueError
(
f
"Found zmq.constants.SOCKET_LIMIT=
{
socket_limit
}
, which caps "
"the number of concurrent requests vLLM can process. Launch "
"vLLM with --disable-frontend-multiprocessing and open a "
"GitHub issue so we can investigate."
)
# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self
.
context
.
set
(
zmq
.
constants
.
MAX_SOCKETS
,
socket_limit
)
# IPC connection to RPC Server (uses unix sockets).
self
.
to_rpc_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
to_rpc_server
.
bind
(
rpc_path
)
# In process proxy to RPC Server (uses memory-based messaging).
self
.
from_api_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
from_api_server
.
bind
(
INPROC_PROXY_PATH
)
# Asyncio background task for the proxy.
self
.
proxy_in_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
from_api_server
,
self
.
to_rpc_server
))
self
.
proxy_out_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
to_rpc_server
,
self
.
from_api_server
))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
# mulitprocessing. This value is used uvicorn to launch
# with --limit-concurrency to return 503 when server is overloaded.
# We need 2 sockets per request - 2:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self
.
limit_concurrency
=
socket_limit
//
2
-
2
async
def
run_proxy
(
self
,
socket_from
:
Socket
,
socket_to
:
Socket
):
"""Background task that runs a proxy"""
while
True
:
frames
=
await
socket_from
.
recv_multipart
(
copy
=
False
)
await
socket_to
.
send_multipart
(
frames
,
copy
=
False
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
await
self
.
_wait_for_server_rpc
()
# Get the configs.
self
.
model_config
=
await
self
.
_get_model_config_rpc
()
self
.
decoding_config
=
await
self
.
_get_decoding_config_rpc
()
self
.
tracing_flag
=
await
self
.
_is_tracing_enabled_rpc
()
# Create the tokenizer group.
# TODO: refactor OAI server to avoid needing this info.
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
scheduler_config
=
(
await
self
.
_get_scheduler_config_rpc
()),
parallel_config
=
(
await
self
.
_get_parallel_config_rpc
()),
enable_lora
=
bool
(
await
self
.
_get_lora_config_rpc
()),
)
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self
.
from_api_server
.
close
()
self
.
to_rpc_server
.
close
()
self
.
context
.
destroy
()
@
contextmanager
def
to_proxy_socket
(
self
)
->
Iterator
[
Socket
]:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if
self
.
context
.
closed
:
raise
RPCClientClosedError
(
"The ZMQ client has already shut down"
)
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
try
:
socket
.
connect
(
INPROC_PROXY_PATH
)
yield
socket
finally
:
socket
.
close
(
linger
=
0
)
async
def
_send_get_data_rpc_request
(
self
,
request
:
RPCUtilityRequest
,
expected_type
:
Any
,
error_message
:
str
)
->
Any
:
"""Send an RPC request that is expecting data back."""
with
self
.
to_proxy_socket
()
as
socket
:
# Ping RPCServer with a request.
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
request
),
),
copy
=
False
)
# Make sure the server responds
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
# Await the data from the Server.
frame
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
frame
,
Frame
)
data
=
pickle
.
loads
(
frame
.
buffer
)
if
isinstance
(
data
,
Exception
):
# Re-raise exceptions returned by the server
raise
data
if
not
isinstance
(
data
,
expected_type
):
# LoRAConfig can be None.
if
expected_type
==
LoRAConfig
and
data
is
None
:
pass
elif
isinstance
(
data
,
Exception
):
logger
.
error
(
error_message
)
raise
data
else
:
raise
ValueError
(
error_message
)
return
data
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
socket
:
Optional
[
Socket
]
=
None
):
"""Send one-way RPC request to trigger an action."""
async
def
do_rpc_call
(
socket
:
Socket
,
request
:
RPC_REQUEST_TYPE
):
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
request
),
))
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
frame
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
frame
,
Frame
)
return
pickle
.
loads
(
frame
.
buffer
)
# Make a new socket connection.
if
socket
is
None
:
with
self
.
to_proxy_socket
()
as
socket
:
response
=
await
do_rpc_call
(
socket
,
request
)
# Use existing socket connection.
else
:
response
=
await
do_rpc_call
(
socket
,
request
)
if
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
:
if
isinstance
(
response
,
Exception
):
logger
.
error
(
error_message
)
raise
response
raise
ValueError
(
error_message
)
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
return
self
.
decoding_config
async
def
get_model_config
(
self
)
->
ModelConfig
:
return
self
.
model_config
async
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracing_flag
async
def
_wait_for_server_rpc
(
self
):
"""Wait for the RPCServer to start up."""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_READY
,
error_message
=
"Unable to start RPC Server"
)
async
def
_get_model_config_rpc
(
self
)
->
ModelConfig
:
"""Get the ModelConfig object from the RPC Server"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_MODEL_CONFIG
,
expected_type
=
ModelConfig
,
error_message
=
"Could not get ModelConfig from RPC Server"
)
async
def
_get_decoding_config_rpc
(
self
)
->
DecodingConfig
:
"""Get DecodingConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_DECODING_CONFIG
,
expected_type
=
DecodingConfig
,
error_message
=
"Could not get DecodingConfig from RPC Server"
)
async
def
_get_parallel_config_rpc
(
self
)
->
ParallelConfig
:
"""Get ParallelConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
,
expected_type
=
ParallelConfig
,
error_message
=
"Could not get ParallelConfig from RPC Server"
)
async
def
_get_scheduler_config_rpc
(
self
)
->
SchedulerConfig
:
"""Get SchedulerConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
,
expected_type
=
SchedulerConfig
,
error_message
=
"Could not get SchedulerConfig from RPC Server"
)
async
def
_get_lora_config_rpc
(
self
)
->
LoRAConfig
:
"""Get LoRAConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
GET_LORA_CONFIG
,
expected_type
=
LoRAConfig
,
error_message
=
"Could not get LoRAConfig from RPC Server"
)
async
def
_is_tracing_enabled_rpc
(
self
)
->
bool
:
"""Get is_tracing_enabled flag from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
IS_TRACING_ENABLED
,
expected_type
=
bool
,
error_message
=
"Could not get is_tracing_enabled from RPC Server"
)
async
def
abort
(
self
,
request_id
:
str
):
"""Send an ABORT_REQUEST signal to the RPC Server"""
# Suppress timeouts as well.
# In cases where the server is busy processing requests and a very
# large volume of abort requests arrive, it is likely that the server
# will not be able to ack all of them in time. We have seen this when
# we abort 20k requests at once while another 2k are processing- many
# of them time out, but we see the server successfully abort all of the
# requests.
# In this case we assume that the server has received or will receive
# these abort requests, and ignore the timeout. This prevents a massive
# wall of `TimeoutError` stack traces.
with
suppress
(
RPCClientClosedError
,
TimeoutError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
error_message
=
f
"RPCAbortRequest
{
request_id
}
failed"
)
async
def
do_log_stats
(
self
):
"""Send a DO_LOG_STATS signal to the RPC Server"""
with
suppress
(
RPCClientClosedError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
DO_LOG_STATS
,
error_message
=
"RPCRequest DO_LOG_STATS failed."
)
@
property
def
is_running
(
self
)
->
bool
:
return
not
self
.
_errored
@
property
def
is_stopped
(
self
)
->
bool
:
return
self
.
_errored
@
property
def
errored
(
self
)
->
bool
:
return
self
.
_errored
async
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
finished
=
False
try
:
with
self
.
to_proxy_socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
RPCGenerateRequest
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)),
))
# Stream back the results from the RPC Server.
while
not
finished
:
message
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
message
,
Frame
)
request_output
=
pickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request_output
,
Exception
):
# On exception, check if the server is still healthy
# possibly setting the `errored` property.
if
not
self
.
_errored
:
try
:
await
self
.
check_health
(
socket
=
socket
)
except
Exception
as
e
:
self
.
_errored
=
True
logger
.
exception
(
repr
(
e
))
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise
request_output
finished
=
request_output
.
finished
yield
request_output
finally
:
# Request was canceled by the client.
if
not
finished
and
not
self
.
_errored
:
await
self
.
abort
(
request_id
)
async
def
check_health
(
self
,
socket
:
Optional
[
Socket
]
=
None
)
->
None
:
"""Raise if unhealthy"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_HEALTHY
,
error_message
=
"Got Unhealthy response from RPC Server"
,
socket
=
socket
)
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
raise
NotImplementedError
(
"Embeddings not supported with multiprocessing backend"
)
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
START_PROFILE
,
error_message
=
"RPCRequest START_PROFILE failed."
)
async
def
stop_profile
(
self
)
->
None
:
"""Stop profiling the engine"""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
STOP_PROFILE
,
error_message
=
"RPCRequest STOP_PROFILE failed."
)
vllm/entrypoints/openai/rpc/server.py
deleted
100644 → 0
View file @
93872128
import
asyncio
import
pickle
import
signal
from
typing
import
Any
,
Coroutine
,
Union
import
cloudpickle
import
uvloop
import
zmq
import
zmq.asyncio
from
typing_extensions
import
Never
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
class
AsyncEngineRPCServer
:
def
__init__
(
self
,
async_engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
rpc_path
:
str
):
# Initialize engine first.
self
.
engine
=
AsyncLLMEngine
.
from_engine_args
(
async_engine_args
,
usage_context
=
usage_context
)
# Initialize context.
self
.
context
=
zmq
.
asyncio
.
Context
()
# Init socket.
self
.
socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
socket
.
connect
(
rpc_path
)
def
cleanup
(
self
):
"""Cleanup all resources."""
self
.
socket
.
close
()
self
.
context
.
destroy
()
self
.
engine
.
shutdown_background_loop
()
# Clear the engine reference so that it can be GC'ed.
del
self
.
engine
async
def
get_config
(
self
,
identity
,
request
):
try
:
config
:
CONFIG_TYPE
if
request
==
RPCUtilityRequest
.
GET_MODEL_CONFIG
:
config
=
await
self
.
engine
.
get_model_config
()
elif
request
==
RPCUtilityRequest
.
GET_DECODING_CONFIG
:
config
=
await
self
.
engine
.
get_decoding_config
()
elif
request
==
RPCUtilityRequest
.
GET_LORA_CONFIG
:
config
=
await
self
.
engine
.
get_lora_config
()
elif
request
==
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
:
config
=
await
self
.
engine
.
get_scheduler_config
()
elif
request
==
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
:
config
=
await
self
.
engine
.
get_parallel_config
()
else
:
raise
ValueError
(
"Unknown Config Request: %s"
,
request
)
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
config
)),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
is_tracing_enabled
(
self
,
identity
):
"""Send the is_tracing_enabled flag"""
tracing_flag
=
await
self
.
engine
.
is_tracing_enabled
()
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
tracing_flag
)))
async
def
do_log_stats
(
self
,
identity
):
"""Log stats and confirm success."""
await
self
.
engine
.
do_log_stats
()
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)))
async
def
is_server_ready
(
self
,
identity
):
"""Notify the client that we are ready."""
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)))
async
def
abort
(
self
,
identity
,
request
:
RPCAbortRequest
):
"""Abort request and notify the client of success."""
try
:
# Abort the request in the llm engine.
await
self
.
engine
.
abort
(
request
.
request_id
)
result
:
Union
[
str
,
Exception
]
=
VLLM_RPC_SUCCESS_STR
except
Exception
as
e
:
result
=
e
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
result
)))
async
def
generate
(
self
,
identity
,
generate_request
:
RPCGenerateRequest
):
try
:
results_generator
=
self
.
engine
.
generate
(
generate_request
.
inputs
,
sampling_params
=
generate_request
.
sampling_params
,
request_id
=
generate_request
.
request_id
,
lora_request
=
generate_request
.
lora_request
,
trace_headers
=
generate_request
.
trace_headers
,
prompt_adapter_request
=
generate_request
.
prompt_adapter_request
)
async
for
request_output
in
results_generator
:
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
request_output
)),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
check_health
(
self
,
identity
):
try
:
await
self
.
engine
.
check_health
()
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)))
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
start_profile
(
self
,
identity
):
logger
.
info
(
"Starting profiler..."
)
await
self
.
engine
.
start_profile
()
logger
.
info
(
"Profiler started."
)
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
))
async
def
stop_profile
(
self
,
identity
):
logger
.
info
(
"Stopping profiler..."
)
await
self
.
engine
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
))
def
_make_handler_coro
(
self
,
identity
,
message
:
Frame
)
->
Coroutine
[
Any
,
Any
,
Never
]:
"""Route the zmq message to the handler coroutine."""
request
=
cloudpickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request
,
RPCGenerateRequest
):
return
self
.
generate
(
identity
,
request
)
elif
isinstance
(
request
,
RPCAbortRequest
):
return
self
.
abort
(
identity
,
request
)
elif
isinstance
(
request
,
RPCUtilityRequest
):
if
request
in
[
RPCUtilityRequest
.
GET_MODEL_CONFIG
,
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
,
RPCUtilityRequest
.
GET_DECODING_CONFIG
,
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
,
RPCUtilityRequest
.
GET_LORA_CONFIG
]:
return
self
.
get_config
(
identity
,
request
)
elif
request
==
RPCUtilityRequest
.
DO_LOG_STATS
:
return
self
.
do_log_stats
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_SERVER_READY
:
return
self
.
is_server_ready
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_SERVER_HEALTHY
:
return
self
.
check_health
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_TRACING_ENABLED
:
return
self
.
is_tracing_enabled
(
identity
)
elif
request
==
RPCUtilityRequest
.
START_PROFILE
:
return
self
.
start_profile
(
identity
)
elif
request
==
RPCUtilityRequest
.
STOP_PROFILE
:
return
self
.
stop_profile
(
identity
)
else
:
raise
ValueError
(
f
"Unknown RPCUtilityRequest type:
{
request
}
"
)
else
:
raise
ValueError
(
f
"Unknown RPCRequest type:
{
request
}
"
)
async
def
run_server_loop
(
self
):
"""Inner RPC Server Loop"""
running_tasks
=
set
()
while
True
:
# Wait for a request.
identity
,
message
=
await
self
.
socket
.
recv_multipart
(
copy
=
False
)
# Process the request async.
task
=
asyncio
.
create_task
(
self
.
_make_handler_coro
(
identity
,
message
))
# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. Below is a common "fire-and-forget" tasks
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
running_tasks
.
add
(
task
)
task
.
add_done_callback
(
running_tasks
.
discard
)
async
def
run_server
(
server
:
AsyncEngineRPCServer
):
# Put the server task into the asyncio loop.
loop
=
asyncio
.
get_running_loop
()
server_task
=
loop
.
create_task
(
server
.
run_server_loop
())
# Interruption handling.
def
signal_handler
()
->
None
:
# Kill the server on interrupt / terminate
server_task
.
cancel
()
loop
.
add_signal_handler
(
signal
.
SIGINT
,
signal_handler
)
loop
.
add_signal_handler
(
signal
.
SIGTERM
,
signal_handler
)
try
:
await
server_task
except
asyncio
.
CancelledError
:
logger
.
info
(
"vLLM ZMQ RPC Server was interrupted."
)
finally
:
# Clean up all resources.
server
.
cleanup
()
def
run_rpc_server
(
async_engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
rpc_path
:
str
):
server
=
AsyncEngineRPCServer
(
async_engine_args
,
usage_context
,
rpc_path
)
uvloop
.
run
(
run_server
(
server
))
vllm/entrypoints/openai/run_batch.py
View file @
539aa992
...
@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
...
@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
...
@@ -196,6 +197,10 @@ async def main(args):
...
@@ -196,6 +197,10 @@ async def main(args):
engine_args
,
usage_context
=
UsageContext
.
OPENAI_BATCH_RUNNER
)
engine_args
,
usage_context
=
UsageContext
.
OPENAI_BATCH_RUNNER
)
model_config
=
await
engine
.
get_model_config
()
model_config
=
await
engine
.
get_model_config
()
base_model_paths
=
[
BaseModelPath
(
name
=
name
,
model_path
=
args
.
model
)
for
name
in
served_model_names
]
if
args
.
disable_log_requests
:
if
args
.
disable_log_requests
:
request_logger
=
None
request_logger
=
None
...
@@ -206,7 +211,7 @@ async def main(args):
...
@@ -206,7 +211,7 @@ async def main(args):
openai_serving_chat
=
OpenAIServingChat
(
openai_serving_chat
=
OpenAIServingChat
(
engine
,
engine
,
model_config
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
args
.
response_role
,
args
.
response_role
,
lora_modules
=
None
,
lora_modules
=
None
,
prompt_adapters
=
None
,
prompt_adapters
=
None
,
...
@@ -216,7 +221,7 @@ async def main(args):
...
@@ -216,7 +221,7 @@ async def main(args):
openai_serving_embedding
=
OpenAIServingEmbedding
(
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
engine
,
model_config
,
model_config
,
se
rved
_model_
name
s
,
ba
se_model_
path
s
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
)
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
539aa992
...
@@ -9,7 +9,7 @@ from typing import Union
...
@@ -9,7 +9,7 @@ from typing import Union
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_hf_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
apply_mistral_chat_template
,
...
@@ -22,8 +22,10 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -22,8 +22,10 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaFunctionCall
,
DeltaMessage
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
ToolCall
,
UsageInfo
)
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
RequestResponseMetadata
,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
OpenAIServing
,
OpenAIServing
,
PromptAdapterPath
,
PromptAdapterPath
,
TextTokensPrompt
)
TextTokensPrompt
)
...
@@ -45,9 +47,9 @@ logger = init_logger(__name__)
...
@@ -45,9 +47,9 @@ logger = init_logger(__name__)
class
OpenAIServingChat
(
OpenAIServing
):
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
def
__init__
(
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
response_role
:
str
,
response_role
:
str
,
*
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
...
@@ -57,9 +59,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -57,9 +59,9 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids
:
bool
=
False
,
return_tokens_as_token_ids
:
bool
=
False
,
enable_auto_tools
:
bool
=
False
,
enable_auto_tools
:
bool
=
False
,
tool_parser
:
Optional
[
str
]
=
None
):
tool_parser
:
Optional
[
str
]
=
None
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
...
@@ -105,6 +107,12 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -105,6 +107,12 @@ class OpenAIServingChat(OpenAIServing):
logger
.
error
(
"Error with model %s"
,
error_check_ret
)
logger
.
error
(
"Error with model %s"
,
error_check_ret
)
return
error_check_ret
return
error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if
self
.
engine_client
.
errored
:
raise
self
.
engine_client
.
dead_error
try
:
try
:
(
(
lora_request
,
lora_request
,
...
@@ -112,8 +120,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -112,8 +120,7 @@ class OpenAIServingChat(OpenAIServing):
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
model_config
=
self
.
model_config
model_config
=
self
.
model_config
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
lora_request
)
conversation
,
mm_data_future
=
parse_chat_messages_futures
(
conversation
,
mm_data_future
=
parse_chat_messages_futures
(
request
.
messages
,
model_config
,
tokenizer
)
request
.
messages
,
model_config
,
tokenizer
)
...
@@ -123,7 +130,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -123,7 +130,8 @@ class OpenAIServingChat(OpenAIServing):
]
]
prompt
:
Union
[
str
,
List
[
int
]]
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
is_mistral_tokenizer
=
isinstance
(
tokenizer
,
MistralTokenizer
)
if
is_mistral_tokenizer
:
prompt
=
apply_mistral_chat_template
(
prompt
=
apply_mistral_chat_template
(
tokenizer
,
tokenizer
,
messages
=
request
.
messages
,
messages
=
request
.
messages
,
...
@@ -159,15 +167,20 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -159,15 +167,20 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"tool_choice =
\"
required
\"
is not supported!"
)
"tool_choice =
\"
required
\"
is not supported!"
)
# "auto" tools requires --enable-auto-tool-choice
if
not
is_mistral_tokenizer
and
request
.
tool_choice
==
"auto"
and
not
(
# and --tool-call-parser
if
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
self
.
tool_parser
is
not
None
):
self
.
enable_auto_tools
and
self
.
tool_parser
is
not
None
):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"
\"
auto
\"
tool choice requires "
"
\"
auto
\"
tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
request_id
=
f
"chat-
{
random_uuid
()
}
"
request_id
=
f
"chat-
{
random_uuid
()
}
"
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
try
:
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
...
@@ -206,8 +219,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -206,8 +219,8 @@ class OpenAIServingChat(OpenAIServing):
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
engine_inputs
[
"multi_modal_data"
]
=
mm_data
engine_inputs
[
"multi_modal_data"
]
=
mm_data
is_tracing_enabled
=
(
is_tracing_enabled
=
(
await
await
self
.
async_
engine_client
.
is_tracing_enabled
())
self
.
engine_client
.
is_tracing_enabled
())
trace_headers
=
None
trace_headers
=
None
if
is_tracing_enabled
and
raw_request
:
if
is_tracing_enabled
and
raw_request
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
@@ -215,7 +228,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -215,7 +228,7 @@ class OpenAIServingChat(OpenAIServing):
and
contains_trace_headers
(
raw_request
.
headers
)):
and
contains_trace_headers
(
raw_request
.
headers
)):
log_tracing_disabled_warning
()
log_tracing_disabled_warning
()
result_generator
=
self
.
async_
engine_client
.
generate
(
result_generator
=
self
.
engine_client
.
generate
(
engine_inputs
,
engine_inputs
,
sampling_params
,
sampling_params
,
request_id
,
request_id
,
...
@@ -234,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -234,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
# Streaming response
if
request
.
stream
:
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
,
request_metadata
)
try
:
try
:
return
await
self
.
chat_completion_full_generator
(
return
await
self
.
chat_completion_full_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
,
request_metadata
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -255,8 +270,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -255,8 +270,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
conversation
:
List
[
ConversationMessage
],
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
model_name
=
self
.
se
rved
_model_name
s
[
0
]
model_name
=
self
.
ba
se_model_
paths
[
0
].
name
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
chunk_object_type
:
Final
=
"chat.completion.chunk"
chunk_object_type
:
Final
=
"chat.completion.chunk"
first_iteration
=
True
first_iteration
=
True
...
@@ -293,6 +309,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -293,6 +309,8 @@ class OpenAIServingChat(OpenAIServing):
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
res
.
prompt_token_ids
is
not
None
:
if
res
.
prompt_token_ids
is
not
None
:
num_prompt_tokens
=
len
(
res
.
prompt_token_ids
)
num_prompt_tokens
=
len
(
res
.
prompt_token_ids
)
if
res
.
encoder_prompt_token_ids
is
not
None
:
num_prompt_tokens
+=
len
(
res
.
encoder_prompt_token_ids
)
# We need to do it here, because if there are exceptions in
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# the result_generator, it needs to be sent as the FIRST
...
@@ -573,6 +591,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -573,6 +591,13 @@ class OpenAIServingChat(OpenAIServing):
exclude_unset
=
True
,
exclude_none
=
True
))
exclude_unset
=
True
,
exclude_none
=
True
))
yield
f
"data:
{
final_usage_data
}
\n\n
"
yield
f
"data:
{
final_usage_data
}
\n\n
"
# report to FastAPI middleware aggregate usage across all choices
num_completion_tokens
=
sum
(
previous_num_tokens
)
request_metadata
.
final_usage_info
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_completion_tokens
,
total_tokens
=
num_prompt_tokens
+
num_completion_tokens
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
logger
.
error
(
"error in chat completion stream generator: %s"
,
e
)
logger
.
error
(
"error in chat completion stream generator: %s"
,
e
)
...
@@ -588,9 +613,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -588,9 +613,10 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
conversation
:
List
[
ConversationMessage
],
conversation
:
List
[
ConversationMessage
],
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
model_name
=
self
.
se
rved
_model_name
s
[
0
]
model_name
=
self
.
ba
se_model_
paths
[
0
].
name
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
final_res
:
Optional
[
RequestOutput
]
=
None
final_res
:
Optional
[
RequestOutput
]
=
None
...
@@ -707,6 +733,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -707,6 +733,9 @@ class OpenAIServingChat(OpenAIServing):
completion_tokens
=
num_generated_tokens
,
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
)
request_metadata
.
final_usage_info
=
usage
response
=
ChatCompletionResponse
(
response
=
ChatCompletionResponse
(
id
=
request_id
,
id
=
request_id
,
created
=
created_time
,
created
=
created_time
,
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment