Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0640f227
Commit
0640f227
authored
Sep 09, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.0' into v0.6.0-dev
parents
82f1ffdf
32e7db25
Changes
335
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1820 additions
and
578 deletions
+1820
-578
vllm/core/scheduler.py
vllm/core/scheduler.py
+159
-29
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+25
-16
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+25
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+34
-10
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+99
-90
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+435
-74
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+5
-8
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+68
-36
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+67
-32
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+2
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+1
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+325
-110
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+37
-15
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+36
-17
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+18
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+125
-14
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+42
-44
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+27
-21
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+29
-5
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+261
-53
No files found.
vllm/core/scheduler.py
View file @
0640f227
...
...
@@ -4,7 +4,8 @@ import random
import
time
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
(
Callable
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
...
...
@@ -220,10 +221,10 @@ class SchedulerSwappedInOutputs:
"""
# Selected sequences that are going to be swapped in and is in a
# decoding phase.
decode_seq_groups
:
List
[
SequenceGroup
]
decode_seq_groups
:
List
[
Scheduled
SequenceGroup
]
# Selected sequences that are going to be swapped in and in a prefill
# phase. I.e., it means the prefill has been chunked.
prefill_seq_groups
:
List
[
SequenceGroup
]
prefill_seq_groups
:
List
[
Scheduled
SequenceGroup
]
# The blocks to swap in.
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]]
# The blocks to copy.
...
...
@@ -253,7 +254,7 @@ class SchedulerPrefillOutputs:
to be recomputed from scratch.
"""
# Selected sequences for prefill.
seq_groups
:
List
[
SequenceGroup
]
seq_groups
:
List
[
Scheduled
SequenceGroup
]
# Ignored sequence groups.
ignored_seq_groups
:
List
[
SequenceGroup
]
num_lookahead_slots
:
int
...
...
@@ -288,7 +289,9 @@ def scheduler_running_outputs_builder():
def
scheduled_seq_group_builder
():
return
ScheduledSequenceGroup
(
seq_group
=
None
,
token_chunk_size
=
0
)
return
ScheduledSequenceGroup
(
SequenceGroup
(
""
,
[],
-
1
),
token_chunk_size
=
0
)
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
class
Scheduler
:
...
...
@@ -299,6 +302,7 @@ class Scheduler:
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
pipeline_parallel_size
:
int
=
1
,
output_proc_callback
:
Optional
[
Callable
]
=
None
,
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
...
...
@@ -364,10 +368,36 @@ class Scheduler:
self
.
num_cumulative_preemption
:
int
=
0
# Used to cache python objects
self
.
_scheduler_running_outputs_cache
:
PyObjectCache
=
PyObjectCache
(
scheduler_running_outputs_builder
)
self
.
_scheduled_seq_group_cache
:
PyObjectCache
=
PyObjectCache
(
scheduled_seq_group_builder
)
self
.
_seq_group_metadata_cache
:
List
[
PyObjectCache
]
=
[]
self
.
_scheduler_running_outputs_cache
:
List
[
PyObjectCache
]
=
[]
self
.
_scheduled_seq_group_cache
:
List
[
PyObjectCache
]
=
[]
# For async output processing, we need to swap cache buffers between
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self
.
output_proc_callback
=
output_proc_callback
self
.
use_async_output_proc
=
self
.
output_proc_callback
is
not
None
self
.
num_cache_iters
=
2
if
self
.
use_async_output_proc
else
1
self
.
cache_id
=
0
for
i
in
range
(
self
.
num_cache_iters
):
self
.
_seq_group_metadata_cache
.
append
(
PyObjectCache
(
seq_group_metadata_builder
))
self
.
_scheduler_running_outputs_cache
.
append
(
PyObjectCache
(
scheduler_running_outputs_builder
))
self
.
_scheduled_seq_group_cache
.
append
(
PyObjectCache
(
scheduled_seq_group_builder
))
# For async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during schedule() call and added to this stop list
# for processing and deallocation by the free_finished_seq_groups()
self
.
_async_stopped
:
List
[
SequenceGroup
]
=
[]
@
property
def
next_cache_id
(
self
):
return
(
self
.
cache_id
+
1
)
%
self
.
num_cache_iters
@
property
def
lora_enabled
(
self
)
->
bool
:
...
...
@@ -483,7 +513,7 @@ class Scheduler:
SchedulerRunningOutputs.
"""
ret
:
SchedulerRunningOutputs
=
\
self
.
_scheduler_running_outputs_cache
.
get_object
()
self
.
_scheduler_running_outputs_cache
[
self
.
cache_id
]
.
get_object
()
ret
.
blocks_to_swap_out
.
clear
()
ret
.
blocks_to_copy
.
clear
()
ret
.
decode_seq_groups
.
clear
()
...
...
@@ -510,8 +540,12 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
running_queue
=
self
.
running
# Store original running requests for the case of async + preemption
if
self
.
use_async_output_proc
:
orig_running
=
self
.
running
.
copy
()
running_queue
=
self
.
running
assert
len
(
self
.
_async_stopped
)
==
0
while
running_queue
:
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
...
...
@@ -521,6 +555,28 @@ class Scheduler:
break
running_queue
.
popleft
()
# With async postprocessor, an extra decode run is done
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if
self
.
use_async_output_proc
and
seq_group
.
seqs
[
0
].
get_len
(
)
>
self
.
scheduler_config
.
max_model_len
:
self
.
_async_stopped
.
append
(
seq_group
)
continue
# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if
self
.
use_async_output_proc
and
not
self
.
_can_append_slots
(
seq_group
):
tmp
=
self
.
running
self
.
running
=
orig_running
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
()
self
.
running
=
tmp
while
not
self
.
_can_append_slots
(
seq_group
):
budget
.
subtract_num_batched_tokens
(
seq_group
.
request_id
,
num_running_tokens
)
...
...
@@ -556,7 +612,7 @@ class Scheduler:
is_prefill
=
seq_group
.
is_prefill
()
scheduled_seq_group
:
ScheduledSequenceGroup
=
\
self
.
_scheduled_seq_group_cache
.
get_object
()
self
.
_scheduled_seq_group_cache
[
self
.
cache_id
]
.
get_object
()
scheduled_seq_group
.
seq_group
=
seq_group
if
is_prefill
:
scheduled_seq_group
.
token_chunk_size
=
num_running_tokens
...
...
@@ -579,8 +635,8 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
self
.
_scheduler_running_outputs_cache
.
reset
()
self
.
_scheduled_seq_group_cache
.
reset
()
self
.
_scheduler_running_outputs_cache
[
self
.
next_cache_id
]
.
reset
()
self
.
_scheduled_seq_group_cache
[
self
.
next_cache_id
]
.
reset
()
return
ret
...
...
@@ -737,7 +793,7 @@ class Scheduler:
SchedulerPrefillOutputs.
"""
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
Scheduled
SequenceGroup
]
=
[]
waiting_queue
=
self
.
waiting
...
...
@@ -971,16 +1027,21 @@ class Scheduler:
# Update waiting requests.
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
prefill_seq_groups
])
# By default, vLLM scheduler prioritizes prefills.
# Once chunked prefill is enabled,
# the policy is changed to prioritize decode requests.
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
prefill_seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
prefill_seq_groups
])
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
# Update swapped requests.
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
return
SchedulerOutputs
(
...
...
@@ -1031,17 +1092,28 @@ class Scheduler:
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
)
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]:
def
_allow_async_output_proc
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
no_beam_search
=
seq_group
.
sampling_params
is
None
or
(
seq_group
.
sampling_params
.
best_of
==
1
and
not
seq_group
.
sampling_params
.
use_beam_search
)
return
no_beam_search
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
bool
]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time
=
time
.
perf_counter
()
scheduler_outputs
=
self
.
_schedule
()
now
=
time
.
time
()
if
not
self
.
cache_config
.
enable_prefix_caching
:
common_computed_block_nums
=
[]
allow_async_output_proc
:
bool
=
self
.
use_async_output_proc
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
i
,
scheduled_seq_group
in
enumerate
(
...
...
@@ -1050,6 +1122,11 @@ class Scheduler:
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group_metadata
=
self
.
_seq_group_metadata_cache
[
self
.
cache_id
].
get_object
()
seq_group_metadata
.
seq_data
.
clear
()
seq_group_metadata
.
block_tables
.
clear
()
# seq_id -> SequenceData
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
# seq_id -> physical block numbers
...
...
@@ -1057,7 +1134,9 @@ class Scheduler:
if
seq_group
.
is_encoder_decoder
():
# Encoder associated with SequenceGroup
encoder_seq_data
=
seq_group
.
get_encoder_seq
().
data
encoder_seq
=
seq_group
.
get_encoder_seq
()
assert
encoder_seq
is
not
None
encoder_seq_data
=
encoder_seq
.
data
# Block table for cross-attention
# Also managed at SequenceGroup level
cross_block_table
=
self
.
block_manager
.
get_cross_block_table
(
...
...
@@ -1139,13 +1218,20 @@ class Scheduler:
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
if
allow_async_output_proc
:
allow_async_output_proc
=
self
.
_allow_async_output_proc
(
seq_group
)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
self
.
block_manager
.
mark_blocks_as_computed
(
scheduled_seq_group
.
seq_group
)
scheduled_seq_group
.
seq_group
,
scheduled_seq_group
.
token_chunk_size
)
self
.
_seq_group_metadata_cache
[
self
.
next_cache_id
].
reset
()
scheduler_time
=
time
.
perf_counter
()
-
scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
...
...
@@ -1158,7 +1244,12 @@ class Scheduler:
else
:
seq_group
.
metrics
.
scheduler_time
=
scheduler_time
return
seq_group_metadata_list
,
scheduler_outputs
# Move to next cache (if exists)
self
.
cache_id
=
self
.
next_cache_id
# Return results
return
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
def
fork_seq
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
self
.
block_manager
.
fork
(
parent_seq
,
child_seq
)
...
...
@@ -1167,6 +1258,12 @@ class Scheduler:
"""Free a sequence from a block table."""
self
.
block_manager
.
free
(
seq
)
def
_free_finished_seqs
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
"""Free finished seqs in a sequence group."""
for
seq
in
seq_group
.
get_seqs
():
if
seq
.
is_finished
():
self
.
free_seq
(
seq
)
def
free_finished_seq_groups
(
self
)
->
None
:
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
for
seq_group
in
self
.
running
:
...
...
@@ -1179,8 +1276,24 @@ class Scheduler:
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
else
:
remaining
.
append
(
seq_group
)
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
self
.
running
=
remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if
self
.
_async_stopped
:
for
seq_group
in
self
.
_async_stopped
:
self
.
_free_seq_group_cross_attn_blocks
(
seq_group
)
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
self
.
_async_stopped
.
clear
()
def
_allocate_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
self
.
block_manager
.
allocate
(
seq_group
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
...
...
@@ -1347,10 +1460,27 @@ class Scheduler:
for
seq
in
seqs
:
num_new_tokens
+=
seq
.
get_num_new_tokens
()
assert
num_new_tokens
>
0
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search
in a
# decode phase. Do not chunk
in that case
.
# Chunk if a running request cannot fit in
the given budget
.
# If number of seq > 1, it means it is doing beam search
#
in a
decode phase. Do not chunk.
if
enable_chunking
and
len
(
seqs
)
==
1
:
num_new_tokens
=
min
(
num_new_tokens
,
budget
.
remaining_token_budget
())
remaining_token_budget
=
budget
.
remaining_token_budget
()
if
self
.
cache_config
.
enable_prefix_caching
:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block size
# to avoid partial block matching.
block_size
=
self
.
cache_config
.
block_size
reminder
=
budget
.
token_budget
%
block_size
if
reminder
!=
0
:
raise
ValueError
(
"When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f
"(
{
budget
.
token_budget
}
) % block_size "
f
"(
{
block_size
}
) =
{
reminder
}
"
)
if
remaining_token_budget
<
num_new_tokens
:
num_new_tokens
=
(
remaining_token_budget
//
block_size
)
*
block_size
else
:
num_new_tokens
=
min
(
num_new_tokens
,
remaining_token_budget
)
return
num_new_tokens
vllm/distributed/device_communicators/custom_all_reduce_utils.py
View file @
0640f227
...
...
@@ -4,6 +4,7 @@ import os
import
pickle
import
subprocess
import
sys
import
tempfile
from
itertools
import
product
from
typing
import
Dict
,
List
,
Optional
,
Sequence
...
...
@@ -211,20 +212,27 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file.
input_bytes
=
pickle
.
dumps
((
batch_src
,
batch_tgt
))
returned
=
subprocess
.
run
([
sys
.
executable
,
__file__
],
input
=
input_bytes
,
capture_output
=
True
)
# check if the subprocess is successful
try
:
returned
.
check_returncode
()
except
Exception
as
e
:
# wrap raised exception to provide more information
raise
RuntimeError
(
f
"Error happened when batch testing "
f
"peer-to-peer access from
{
batch_src
}
to
{
batch_tgt
}
:
\n
"
f
"
{
returned
.
stderr
.
decode
()
}
"
)
from
e
result
=
pickle
.
loads
(
returned
.
stdout
)
# use a temporary file to store the result
# we don't use the output of the subprocess directly,
# because the subprocess might produce logging output
with
tempfile
.
NamedTemporaryFile
()
as
output_file
:
input_bytes
=
pickle
.
dumps
(
(
batch_src
,
batch_tgt
,
output_file
.
name
))
returned
=
subprocess
.
run
([
sys
.
executable
,
__file__
],
input
=
input_bytes
,
capture_output
=
True
)
# check if the subprocess is successful
try
:
returned
.
check_returncode
()
except
Exception
as
e
:
# wrap raised exception to provide more information
raise
RuntimeError
(
f
"Error happened when batch testing "
f
"peer-to-peer access from
{
batch_src
}
to
{
batch_tgt
}
:
\n
"
f
"
{
returned
.
stderr
.
decode
()
}
"
)
from
e
with
open
(
output_file
.
name
,
"rb"
)
as
f
:
result
=
pickle
.
load
(
f
)
for
_i
,
_j
,
r
in
zip
(
batch_src
,
batch_tgt
,
result
):
cache
[
f
"
{
_i
}
->
{
_j
}
"
]
=
r
with
open
(
path
,
"w"
)
as
f
:
...
...
@@ -241,6 +249,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
__all__
=
[
"gpu_p2p_access_check"
]
if
__name__
==
"__main__"
:
batch_src
,
batch_tgt
=
pickle
.
loads
(
sys
.
stdin
.
buffer
.
read
())
batch_src
,
batch_tgt
,
output_file
=
pickle
.
loads
(
sys
.
stdin
.
buffer
.
read
())
result
=
can_actually_p2p
(
batch_src
,
batch_tgt
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
result
))
with
open
(
output_file
,
"wb"
)
as
f
:
f
.
write
(
pickle
.
dumps
(
result
))
vllm/distributed/device_communicators/tpu_communicator.py
View file @
0640f227
import
os
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
...
...
@@ -5,11 +7,12 @@ from torch.distributed import ProcessGroup
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
():
import
ray
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
from
torch_xla._internal
import
pjrt
from
vllm.executor
import
ray_utils
class
TpuCommunicator
:
...
...
@@ -24,9 +27,29 @@ class TpuCommunicator:
# be simply calculated as follows.
global_rank
=
dist
.
get_rank
(
group
)
global_world_size
=
dist
.
get_world_size
(
group
)
num_nodes
=
len
(
ray
.
nodes
())
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
num_nodes
=
ray_utils
.
get_num_tpu_nodes
()
num_nodes_in_pg
=
ray_utils
.
get_num_nodes_in_placement_group
()
if
num_nodes_in_pg
>
0
:
num_nodes
=
num_nodes_in_pg
local_world_size
=
global_world_size
//
num_nodes
local_rank
=
global_rank
%
local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os
.
environ
[
"CLOUD_TPU_TASK_ID"
]
=
str
(
global_rank
)
os
.
environ
[
"TPU_VISIBLE_CHIPS"
]
=
str
(
local_rank
)
pjrt
.
initialize_multiprocess
(
local_rank
,
local_world_size
)
xr
.
_init_world_size_ordinal
()
...
...
vllm/engine/arg_utils.py
View file @
0640f227
...
...
@@ -2,8 +2,8 @@ import argparse
import
dataclasses
import
json
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
import
torch
...
...
@@ -16,6 +16,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
FlexibleArgumentParser
if
TYPE_CHECKING
:
...
...
@@ -147,6 +148,8 @@ class EngineArgs:
otlp_traces_endpoint
:
Optional
[
str
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
...
...
@@ -197,10 +200,11 @@ class EngineArgs:
'--tokenizer-mode'
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_mode
,
choices
=
[
'auto'
,
'slow'
],
choices
=
[
'auto'
,
'slow'
,
'mistral'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.'
)
'always use the slow tokenizer.
\n
* '
'"mistral" will always use the `mistral_common` tokenizer.'
)
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
help
=
'Trust remote code from huggingface.'
)
...
...
@@ -317,9 +321,10 @@ class EngineArgs:
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
default
=
EngineArgs
.
block_size
,
choices
=
[
8
,
16
,
32
,
128
,
256
,
512
,
1024
,
2048
],
choices
=
[
8
,
16
,
32
],
help
=
'Token block size for contiguous chunks of '
'tokens.'
)
'tokens. This is ignored on neuron devices and '
'set to max-model-len'
)
parser
.
add_argument
(
'--enable-prefix-caching'
,
action
=
'store_true'
,
...
...
@@ -732,6 +737,22 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact."
)
parser
.
add_argument
(
'--disable-async-output-proc'
,
action
=
'store_true'
,
default
=
EngineArgs
.
disable_async_output_proc
,
help
=
"Disable async output processing. This may result in "
"lower performance."
)
parser
.
add_argument
(
'--override-neuron-config'
,
type
=
lambda
configs
:
{
str
(
key
):
value
for
key
,
value
in
(
config
.
split
(
':'
)
for
config
in
configs
.
split
(
','
))
},
default
=
None
,
help
=
"override or set neuron device configuration."
)
return
parser
@
classmethod
...
...
@@ -742,9 +763,9 @@ class EngineArgs:
engine_args
=
cls
(
**
{
attr
:
getattr
(
args
,
attr
)
for
attr
in
attrs
})
return
engine_args
def
create_engine_config
(
self
,
)
->
EngineConfig
:
def
create_engine_config
(
self
)
->
EngineConfig
:
# gguf file needs a specific model loader and doesn't use hf_repo
if
self
.
model
.
endswith
(
".gguf"
):
if
check_gguf_file
(
self
.
model
):
self
.
quantization
=
self
.
load_format
=
"gguf"
# bitsandbytes quantization needs a specific model loader
...
...
@@ -791,9 +812,11 @@ class EngineArgs:
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
)
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
override_neuron_config
=
self
.
override_neuron_config
)
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
swap_space
=
self
.
swap_space
,
cache_dtype
=
self
.
kv_cache_dtype
,
...
...
@@ -910,6 +933,7 @@ class EngineArgs:
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
embedding_mode
=
model_config
.
embedding_mode
,
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
...
...
vllm/engine/async_llm_engine.py
View file @
0640f227
import
asyncio
import
time
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
...
...
@@ -15,7 +13,7 @@ from vllm.core.scheduler import SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
(
DecoderPromptComponents
,
LLMEngine
,
PromptComponents
)
PromptComponents
,
SchedulerOutputState
)
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
...
...
@@ -24,12 +22,12 @@ from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
print_warning_once
...
...
@@ -257,24 +255,11 @@ class RequestTracker:
return
not
self
.
_new_requests
.
empty
()
@
dataclass
class
SchedulerOutputState
:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output
:
Optional
[
SamplerOutput
]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
class
_AsyncLLMEngine
(
LLMEngine
):
"""Extension of LLMEngine to add async methods."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
pipeline_parallel_size
=
\
self
.
parallel_config
.
pipeline_parallel_size
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
pipeline_parallel_size
)
]
async
def
step_async
(
self
,
virtual_engine
:
int
...
...
@@ -293,19 +278,37 @@ class _AsyncLLMEngine(LLMEngine):
cached_outputs
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Clear outputs for each new scheduler iteration
ctx
.
request_outputs
.
clear
()
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
)
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
...
...
@@ -333,14 +336,22 @@ class _AsyncLLMEngine(LLMEngine):
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
output
=
[]
# Finish the current step for all the sequence groups.
...
...
@@ -349,77 +360,45 @@ class _AsyncLLMEngine(LLMEngine):
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
#
c
lear the cache if we have finished all the steps
#
C
lear the cache if we have finished all the steps
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
))
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
"Async postprocessor expects only a single output set"
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
else
:
request_outputs
=
[]
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
return
request_outputs
def
_has_remaining_steps
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
)
->
bool
:
if
(
not
self
.
scheduler_config
.
is_multi_step
or
not
seq_group_metadata_list
):
return
False
# TODO(will) this is a sanity check for nowto make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps
=
seq_group_metadata_list
[
0
].
state
.
remaining_steps
if
any
([
seq_group
.
state
.
remaining_steps
!=
ref_remaining_steps
for
seq_group
in
seq_group_metadata_list
[
1
:]
]):
raise
AssertionError
((
"All running sequence groups should "
"have the same remaining steps."
))
return
ref_remaining_steps
>
0
def
_cache_scheduler_outputs_for_multi_step
(
self
,
virtual_engine
:
int
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
scheduler_outputs
:
SchedulerOutputs
)
->
None
:
self
.
cached_scheduler_outputs
[
virtual_engine
].
seq_group_metadata_list
=
seq_group_metadata_list
self
.
cached_scheduler_outputs
[
virtual_engine
].
scheduler_outputs
=
\
scheduler_outputs
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
=
None
def
_get_last_sampled_token_ids
(
self
,
virtual_engine
:
int
)
->
Optional
[
torch
.
Tensor
]:
cached_last_output
=
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
if
(
self
.
scheduler_config
.
is_multi_step
and
self
.
parallel_config
.
pipeline_parallel_size
>
1
and
cached_last_output
is
not
None
and
cached_last_output
.
sampled_token_ids_cpu
is
not
None
):
return
cached_last_output
.
sampled_token_ids_cpu
return
None
# Multi-step case
return
ctx
.
request_outputs
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
assert
len
(
ctx
.
output_queue
)
==
0
def
_update_cached_scheduler_output
(
self
,
virtual_engine
:
int
,
output
:
List
[
Optional
[
SamplerOutput
]])
->
None
:
if
(
self
.
parallel_config
.
pipeline_parallel_size
>
1
and
len
(
output
)
>
0
and
output
[
0
]
is
not
None
):
last_output
=
output
[
-
1
]
assert
last_output
is
not
None
assert
last_output
.
sampled_token_ids_cpu
is
not
None
assert
last_output
.
sampled_token_ids
is
None
assert
last_output
.
sampled_token_probs
is
None
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
=
last_output
return
ctx
.
request_outputs
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
"""Stop the remote worker execution loop."""
...
...
@@ -635,6 +614,17 @@ class AsyncLLMEngine:
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self
.
use_process_request_outputs_callback
=
not
self
.
engine_use_ray
if
self
.
use_process_request_outputs_callback
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
process_request_outputs
if
self
.
engine_use_ray
:
print_warning_once
(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
...
...
@@ -702,6 +692,11 @@ class AsyncLLMEngine:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutorAsync
executor_class
=
RayXPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.multiproc_xpu_executor
import
(
MultiprocessingXPUExecutorAsync
)
executor_class
=
MultiprocessingXPUExecutorAsync
else
:
raise
RuntimeError
(
"Not supported distributed execution model on XPU device."
)
...
...
@@ -873,13 +868,27 @@ class AsyncLLMEngine:
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
# Put the outputs into the corresponding streams.
finished
=
True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if
not
self
.
use_process_request_outputs_callback
:
all_finished
=
self
.
process_request_outputs
(
request_outputs
)
else
:
# For callback case, we only need to detect when all
# requests are finished
all_finished
=
all
(
request_output
.
finished
for
request_output
in
request_outputs
)
return
not
all_finished
def
process_request_outputs
(
self
,
request_outputs
)
->
bool
:
# Put the outputs into the corresponding streams.
all_finished
=
True
for
request_output
in
request_outputs
:
self
.
_request_tracker
.
process_request_output
(
request_output
,
verbose
=
self
.
log_requests
)
finished
=
finished
and
request_output
.
finished
all_
finished
=
all_
finished
and
request_output
.
finished
return
not
finished
return
all_
finished
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
if
self
.
engine_use_ray
:
...
...
vllm/engine/llm_engine.py
View file @
0640f227
import
functools
import
time
from
collections
import
deque
from
contextlib
import
contextmanager
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Iterable
,
List
,
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
import
torch
from
typing_extensions
import
TypeVar
,
assert_never
import
vllm.envs
as
envs
...
...
@@ -29,6 +33,7 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
...
...
@@ -36,8 +41,7 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
...
...
@@ -77,6 +81,28 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional
[
MultiModalDataDict
]]
@
dataclass
class
SchedulerOutputState
:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
allow_async_output_proc
:
bool
=
False
last_output
:
Optional
[
SamplerOutput
]
=
None
@
dataclass
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
bool
,
bool
]]
=
field
(
default_factory
=
lambda
:
deque
())
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
field
(
default_factory
=
lambda
:
[])
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -162,11 +188,15 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only
:
bool
=
False
,
)
->
None
:
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
...
...
@@ -176,7 +206,8 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"enable_prefix_caching=%s)"
,
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s)"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -184,6 +215,7 @@ class LLMEngine:
model_config
.
skip_tokenizer_init
,
model_config
.
tokenizer_mode
,
model_config
.
revision
,
model_config
.
override_neuron_config
,
model_config
.
rope_scaling
,
model_config
.
rope_theta
,
model_config
.
tokenizer_revision
,
...
...
@@ -205,7 +237,9 @@ class LLMEngine:
model_config
.
seed
,
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
)
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
...
...
@@ -224,6 +258,7 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
self
.
log_stats
=
log_stats
self
.
step_return_finished_only
=
step_return_finished_only
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
...
@@ -307,13 +342,36 @@ class LLMEngine:
# different process.
self
.
tokenizer
.
ping
()
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callbacks
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
=
None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self
.
scheduler
=
[
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
parallel_config
.
pipeline_parallel_size
)
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
self
.
async_callbacks
[
v_id
]
if
model_config
.
use_async_output_proc
else
None
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
# Metric Logging.
...
...
@@ -421,6 +479,13 @@ class LLMEngine:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutor
executor_class
=
RayXPUExecutor
elif
distributed_executor_backend
==
"mp"
:
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger
.
error
(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead."
)
else
:
from
vllm.executor.xpu_executor
import
XPUExecutor
executor_class
=
XPUExecutor
...
...
@@ -1163,34 +1228,68 @@ class LLMEngine:
return
def
_process_model_outputs
(
self
,
output
:
GenericSequence
[
Union
[
SamplerOutput
,
PoolerOutput
]],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
now
=
time
.
time
()
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group
=
create_output_by_sequence_group
(
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
if
len
(
ctx
.
output_queue
)
==
0
:
return
None
# Get pending async postprocessor
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
)
=
ctx
.
output_queue
.
popleft
()
assert
outputs
is
not
None
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
scheduler_outputs
.
scheduled_seq_groups
)
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if
len
(
outputs
)
>
1
:
outputs_by_sequence_group
=
create_output_by_sequence_group
(
outputs
,
num_seq_groups
=
len
(
seq_group_metadata_list
))
else
:
outputs_by_sequence_group
=
outputs
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
# Update the scheduled sequence groups with the model outputs.
for
scheduled_seq_group
,
outputs
,
seq_group_meta
in
zip
(
scheduled_seq_groups
,
output_by_sequence_group
,
seq_group_metadata_list
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
if
output
is
not
None
and
len
(
output
)
>
0
:
for
o
in
output
:
if
seq_group
.
is_finished
():
finished_before
.
append
(
i
)
continue
if
len
(
outputs
)
>
1
:
output
=
outputs_by_sequence_group
[
i
]
else
:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
if
not
is_async
:
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
if
outputs
:
for
o
in
outputs
:
if
(
isinstance
(
o
,
SamplerOutput
)
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
...
...
@@ -1205,30 +1304,105 @@ class LLMEngine:
else
:
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
if
self
.
model_config
.
embedding_mode
:
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
continue
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
outputs
)
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
outputs
)
if
seq_group
.
is_finished
():
finished_now
.
append
(
i
)
#
Free the finished sequence groups.
for
scheduler
in
self
.
scheduler
:
schedule
r
.
free_finish
ed_seq_groups
()
#
Generate outputs for the requests that finished this iteration
for
i
in
finished_now
:
schedule
d_seq_group
=
scheduler_outputs
.
schedul
ed_seq_groups
[
i
]
# Create the outputs.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
for
scheduled_seq_group
in
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
for
seq_group
in
ignored_seq_groups
:
ctx
.
request_outputs
.
append
(
request_output
)
# Free currently finished requests
if
finished_now
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
# For multi-step, do not create outputs each iteration
if
not
is_last_step
:
# Immediately process request outputs here (if callback is given)
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
if
i
in
finished_before
or
i
in
finished_now
:
continue
# Avoids double processing
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
(
seq_group
.
is_finished
()
if
self
.
step_return_finished_only
else
True
):
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
return
request_outputs
ctx
.
request_outputs
.
append
(
request_output
)
# Immediately process request outputs here (if callback is given)
if
(
ctx
.
request_outputs
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if
is_async
:
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
return
None
def
_advance_to_next_step
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done inside output processor, but it is
required if the worker is to perform async forward pass to next step.
"""
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
output
,
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
continue
seq_group
.
update_num_computed_tokens
(
seq_group_metadata
.
token_chunk_size
)
if
seq_group_metadata
.
do_sample
:
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)"
)
sample
=
sequence_group_outputs
.
samples
[
0
]
assert
len
(
seq_group
.
seqs
)
==
1
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
...
...
@@ -1286,16 +1460,60 @@ class LLMEngine:
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise."
)
if
self
.
scheduler_config
.
num_scheduler_steps
>
1
:
raise
NotImplementedError
(
"Multiple scheduler steps (multi-step) are only supported "
"through AsyncLLMEngine. "
)
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
[
0
].
schedule
()
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0.
virtual_engine
=
0
# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Clear outputs for each new scheduler iteration
ctx
.
request_outputs
.
clear
()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
0
].
get_and_reset_finished_requests_ids
()
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
...
...
@@ -1303,23 +1521,74 @@ class LLMEngine:
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
)
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# No outputs in this case
output
=
[]
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
# Add results to the output_queue
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
))
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
else
:
# Multi-step case
return
ctx
.
request_outputs
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
assert
len
(
ctx
.
output_queue
)
==
0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
...
...
@@ -1327,32 +1596,97 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
request_outputs
return
ctx
.
request_outputs
def
_has_remaining_steps
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
)
->
bool
:
if
(
not
self
.
scheduler_config
.
is_multi_step
or
not
seq_group_metadata_list
):
return
False
# TODO(will) this is a sanity check for nowto make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps
=
seq_group_metadata_list
[
0
].
state
.
remaining_steps
if
any
([
seq_group
.
state
.
remaining_steps
!=
ref_remaining_steps
for
seq_group
in
seq_group_metadata_list
[
1
:]
]):
raise
AssertionError
((
"All running sequence groups should "
"have the same remaining steps."
))
return
ref_remaining_steps
>
0
def
_cache_scheduler_outputs_for_multi_step
(
self
,
virtual_engine
:
int
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
scheduler_outputs
:
SchedulerOutputs
,
allow_async_output_proc
:
bool
)
->
None
:
co
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
co
.
seq_group_metadata_list
=
seq_group_metadata_list
co
.
scheduler_outputs
=
scheduler_outputs
co
.
allow_async_output_proc
=
allow_async_output_proc
co
.
last_output
=
None
def
_update_cached_scheduler_output
(
self
,
virtual_engine
:
int
,
output
:
List
[
Optional
[
SamplerOutput
]])
->
None
:
if
(
self
.
parallel_config
.
pipeline_parallel_size
>
1
and
len
(
output
)
>
0
and
output
[
0
]
is
not
None
):
last_output
=
output
[
-
1
]
assert
last_output
is
not
None
assert
last_output
.
sampled_token_ids_cpu
is
not
None
assert
last_output
.
sampled_token_ids
is
None
assert
last_output
.
sampled_token_probs
is
None
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
=
last_output
def
_get_last_sampled_token_ids
(
self
,
virtual_engine
:
int
)
->
Optional
[
torch
.
Tensor
]:
cached_last_output
=
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
if
(
self
.
scheduler_config
.
is_multi_step
and
self
.
parallel_config
.
pipeline_parallel_size
>
1
and
cached_last_output
is
not
None
and
cached_last_output
.
sampled_token_ids_cpu
is
not
None
):
return
cached_last_output
.
sampled_token_ids_cpu
return
None
def
add_logger
(
self
,
logger_name
:
str
,
logger
:
StatLoggerBase
)
->
None
:
if
not
self
.
log_stats
:
raise
RuntimeError
(
"Stat logging is disabled. Set `disable_log_stats=False` "
"argument to enable."
)
if
logger_name
in
self
.
stat_loggers
:
raise
KeyError
(
f
"Logger with name
{
logger_name
}
already exists."
)
self
.
stat_loggers
[
logger_name
]
=
logger
def
remove_logger
(
self
,
logger_name
:
str
)
->
None
:
if
not
self
.
log_stats
:
raise
RuntimeError
(
"Stat logging is disabled. Set `disable_log_stats=False` "
"argument to enable."
)
if
logger_name
not
in
self
.
stat_loggers
:
raise
KeyError
(
f
"Logger with name
{
logger_name
}
does not exist."
)
del
self
.
stat_loggers
[
logger_name
]
def
do_log_stats
(
self
,
scheduler
_output
s
:
Optional
[
Schedu
lerOutput
s
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutpu
t
]]
=
None
)
->
None
:
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model
_output
:
Optional
[
List
[
Samp
lerOutput
]
]
=
None
,
finished_before
:
Optional
[
List
[
in
t
]]
=
None
)
->
None
:
"""Forced log when no requests active."""
if
self
.
log_stats
:
stats
=
self
.
_get_stats
(
scheduler_outputs
,
model_output
)
stats
=
self
.
_get_stats
(
scheduler_outputs
,
model_output
,
finished_before
)
for
logger
in
self
.
stat_loggers
.
values
():
logger
.
log
(
stats
)
def
_get_stats
(
self
,
scheduler
_output
s
:
Optional
[
SchedulerOutputs
]
,
model_output
:
Optional
[
List
[
SamplerOutpu
t
]]
=
None
)
->
Stats
:
def
_get_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
,
model
_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
,
finished_before
:
Optional
[
List
[
in
t
]]
=
None
)
->
Stats
:
"""Get Stats to be Logged to Prometheus.
Args:
...
...
@@ -1417,6 +1751,10 @@ class LLMEngine:
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if
scheduler_outputs
is
not
None
:
# For async postprocessor, already finished sequences need to be
# not counted (to avoid double counting)
actual_num_batched_tokens
=
scheduler_outputs
.
num_batched_tokens
# type: ignore
num_generation_tokens_from_prefill_groups
=
0.
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# the len of scheduler_outputs.scheduled_seq_groups is !=
...
...
@@ -1425,6 +1763,11 @@ class LLMEngine:
for
idx
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
# Skip double logging when using async output proc
if
finished_before
and
idx
in
finished_before
:
actual_num_batched_tokens
-=
1
continue
group_was_prefill
=
idx
<
scheduler_outputs
.
num_prefill_groups
seq_group
=
scheduled_seq_group
.
seq_group
...
...
@@ -1459,7 +1802,6 @@ class LLMEngine:
# Latency timings
time_e2e_requests
.
append
(
now
-
seq_group
.
metrics
.
arrival_time
)
# Metadata
num_prompt_tokens_requests
.
append
(
len
(
seq_group
.
prompt_token_ids
))
...
...
@@ -1483,7 +1825,7 @@ class LLMEngine:
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter
=
(
scheduler_outputs
.
num_batched_tokens
-
num_prompt_tokens_iter
+
actual_
num_batched_tokens
-
num_prompt_tokens_iter
+
num_generation_tokens_from_prefill_groups
)
# Spec decode, if enabled, emits specialized metrics from the worker in
...
...
@@ -1633,7 +1975,26 @@ class LLMEngine:
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]):
prompt_key
=
"encoder_prompt_token_ids"
\
if
self
.
is_encoder_decoder_model
()
else
"prompt_token_ids"
if
not
inputs
.
get
(
prompt_key
):
raise
ValueError
(
"Prompt cannot be empty"
)
\ No newline at end of file
if
self
.
is_encoder_decoder_model
():
prompt_ids
=
inputs
.
get
(
"encoder_prompt_token_ids"
)
else
:
prompt_ids
=
inputs
.
get
(
"prompt_token_ids"
)
if
prompt_ids
is
None
or
len
(
prompt_ids
)
==
0
:
raise
ValueError
(
"Prompt cannot be empty"
)
if
self
.
model_config
.
is_multimodal_model
:
max_prompt_len
=
self
.
model_config
.
max_model_len
if
len
(
prompt_ids
)
>
max_prompt_len
:
raise
ValueError
(
f
"The prompt (total length
{
len
(
prompt_ids
)
}
) is too long "
f
"to fit into the model (context length
{
max_prompt_len
}
). "
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well."
)
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
vllm/engine/output_processor/interfaces.py
View file @
0640f227
...
...
@@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
# Importing here to avoid cycle.
from
vllm.engine.output_processor.single_step
import
(
SingleStepOutputProcessor
)
return
SingleStepOutputProcessor
(
scheduler_config
,
detokenizer
,
scheduler
,
seq_counter
,
stop_checker
,
)
return
SingleStepOutputProcessor
(
scheduler_config
,
detokenizer
,
scheduler
,
seq_counter
,
stop_checker
)
else
:
# Importing here to avoid cycle.
from
vllm.engine.output_processor.multi_step
import
(
...
...
@@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
@
abstractmethod
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
)
->
None
:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
...
...
vllm/engine/output_processor/multi_step.py
View file @
0640f227
...
...
@@ -4,6 +4,8 @@ from typing import Callable, List
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.single_step
import
(
single_step_process_prompt_logprob
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -46,9 +48,16 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
self
.
_log_prompt_logprob_unsupported_warning_once
()
"""Process prompt logprobs associated with each step of a multi-step-
scheduled computation.
Args:
seq_group: the outputs are associated with this :class:`SequenceGroup`
outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
"""
for
output
in
outputs
:
# Concatenate single-step prompt logprob processing results.
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
@
staticmethod
@
functools
.
lru_cache
()
...
...
@@ -57,37 +66,73 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers)."
)
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
=
False
)
->
None
:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
This applies logic like stop condition checking and detokenization,
including freeing finished sequences. It also handles cases where there
are tokens emitted after the EOS token.
This applies logic like stop condition checking and detokenization.
It also handles cases where there are tokens emitted after
the EOS token.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINSIHED_ABORTED
# if a client disconnects from the api server.
seqs
=
sequence_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
if
seqs
is
None
:
seqs
=
sequence_group
.
get_seqs
(
status
=
SequenceStatus
.
FINISHED_ABORTED
)
assert
seqs
,
"
e
xpected
running
sequences"
assert
seqs
,
"
E
xpected
RUNNING or FINISHED_ABORTED
sequences"
assert
len
(
seqs
)
==
1
,
(
"Beam search not supported in multi-step decoding."
)
seq
=
seqs
[
0
]
# Since there's only one sequence per sequence group, we can take the
# first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
]
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
if
is_async
:
# Async case: We process tokens one by one. Here, we know the token
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self
.
_process_decode_and_stop
(
seq
,
sequence_group
.
sampling_params
)
else
:
# Standard multi-step case
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
]
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
new_char_count
=
0
if
sampling_params
.
detokenize
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
# TODO(sang): Support lora.
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
new_char_count
,
sampling_params
=
sampling_params
,
)
def
_process_seq_outputs
(
self
,
seq
:
Sequence
,
valid_samples
:
List
[
SequenceOutput
],
...
...
@@ -125,20 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs
=
output_logprob
,
)
new_char_count
=
0
if
sampling_params
.
detokenize
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
self
.
_process_decode_and_stop
(
seq
,
sampling_params
)
# TODO(sang): Support lora.
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
new_char_count
,
sampling_params
=
sampling_params
,
)
if
seq
.
is_finished
():
break
if
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
vllm/engine/output_processor/single_step.py
View file @
0640f227
...
...
@@ -15,6 +15,44 @@ from vllm.utils import Counter
logger
=
init_logger
(
__name__
)
def
single_step_process_prompt_logprob
(
sg_output_proc
:
SequenceGroupOutputProcessor
,
seq_group
:
SequenceGroup
,
output
:
SequenceGroupOutput
)
->
None
:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs.
Args:
sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
prompt_logprobs
=
output
.
prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if
prompt_logprobs
is
not
None
:
if
not
seq_group
.
prompt_logprobs
:
prompt_logprobs
=
[
None
]
+
prompt_logprobs
seq_group
.
prompt_logprobs
=
[]
assert
hasattr
(
sg_output_proc
,
'detokenizer'
)
if
(
seq_group
.
sampling_params
.
detokenize
and
sg_output_proc
.
detokenizer
):
sg_output_proc
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
,
position_offset
=
len
(
seq_group
.
prompt_logprobs
))
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
class
SingleStepOutputProcessor
(
SequenceGroupOutputProcessor
):
"""SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before
...
...
@@ -29,14 +67,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
that is currently difficult to schedule multiple steps ahead of time.
"""
def
__init__
(
self
,
scheduler_config
:
SchedulerConfig
,
detokenizer
:
Detokenizer
,
scheduler
:
List
[
Scheduler
],
seq_counter
:
Counter
,
stop_checker
:
StopChecker
,
):
def
__init__
(
self
,
scheduler_config
:
SchedulerConfig
,
detokenizer
:
Detokenizer
,
scheduler
:
List
[
Scheduler
],
seq_counter
:
Counter
,
stop_checker
:
StopChecker
):
self
.
scheduler_config
=
scheduler_config
self
.
detokenizer
=
detokenizer
self
.
scheduler
=
scheduler
...
...
@@ -44,50 +77,49 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self
.
stop_checker
=
stop_checker
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
)
->
None
:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
assert
(
len
(
outputs
)
==
1
),
f
"
{
type
(
self
)
}
does not support multiple outputs per step"
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
])
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
],
is_async
)
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
"""Process prompt logprobs associated with one step of a single-step-
scheduled computation.
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
output
=
outputs
[
0
]
prompt_logprobs
=
output
.
prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if
prompt_logprobs
is
not
None
:
if
not
seq_group
.
prompt_logprobs
:
prompt_logprobs
=
[
None
]
+
prompt_logprobs
seq_group
.
prompt_logprobs
=
[]
if
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
,
position_offset
=
len
(
seq_group
.
prompt_logprobs
))
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
outputs
:
SequenceGroupOutput
,
is_async
:
bool
)
->
None
:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
.
n
==
1
and
not
sampling_params
.
use_beam_search
:
if
sampling_params
.
best_of
==
1
and
not
sampling_params
.
use_beam_search
:
# only have one output sample
sample
=
outputs
.
samples
[
0
]
# only have one sequence
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
not
is_async
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
...
...
@@ -104,6 +136,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler
.
free_seq
(
seq
)
return
# TODO: Add support for async for beam search
assert
not
is_async
# Process samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
...
...
vllm/engine/output_processor/util.py
View file @
0640f227
...
...
@@ -2,7 +2,8 @@ from typing import List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
vllm.sequence
import
PoolerOutput
,
SamplerOutput
,
SequenceGroupOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
PoolerOutput
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
...
...
vllm/engine/protocol.py
View file @
0640f227
...
...
@@ -5,11 +5,11 @@ from vllm.config import DecodingConfig, ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
...
...
vllm/entrypoints/chat_utils.py
View file @
0640f227
import
asyncio
import
codecs
from
dataclasses
import
dataclass
from
functools
import
lru_cache
import
json
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
from
functools
import
lru_cache
,
partial
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
Union
)
from
typing
import
(
Any
,
Awaitable
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
# yapf conflicts with isort for this block
# yapf: disable
from
openai.types.chat
import
ChatCompletionContentPartImageParam
from
openai.types.chat
import
(
ChatCompletionAssistantMessageParam
,
ChatCompletionContentPartImageParam
)
from
openai.types.chat
import
(
ChatCompletionContentPartParam
as
OpenAIChatCompletionContentPartParam
)
from
openai.types.chat
import
ChatCompletionContentPartTextParam
from
openai.types.chat
import
(
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartTextParam
)
from
openai.types.chat
import
(
ChatCompletionMessageParam
as
OpenAIChatCompletionMessageParam
)
from
openai.types.chat
import
(
ChatCompletionMessageToolCallParam
,
ChatCompletionToolMessageParam
)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from
pydantic
import
ConfigDict
,
TypeAdapter
from
pydantic
import
ConfigDict
from
typing_extensions
import
Required
,
TypeAlias
,
TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
async_get_and_parse_image
)
async_get_and_parse_image
,
get_and_parse_audio
,
get_and_parse_image
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
...
...
@@ -51,7 +59,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
CustomChatCompletionContentPartParam
,
]
ChatCompletionContentPartRefusalParam
,
CustomChatCompletionContentPartParam
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
...
...
@@ -69,21 +78,217 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
same role.
"""
tool_call_id
:
Optional
[
str
]
"""Tool call that this message is responding to."""
tool_calls
:
Optional
[
Iterable
[
ChatCompletionMessageToolCallParam
]]
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam
=
Union
[
OpenAIChatCompletionMessageParam
,
CustomChatCompletionMessageParam
]
# TODO: Make fields ReadOnly once mypy supports it
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
class
ConversationMessage
(
TypedDict
,
total
=
False
):
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Optional
[
str
]
"""The contents of the message"""
tool_call_id
:
Optional
[
str
]
"""Tool call that this message is responding to."""
name
:
Optional
[
str
]
"""The name of the function to call"""
tool_calls
:
Optional
[
Iterable
[
ChatCompletionMessageToolCallParam
]]
"""The tool calls generated by the model, such as function calls."""
ModalityStr
=
Literal
[
"image"
,
"audio"
]
_T
=
TypeVar
(
"_T"
)
class
BaseMultiModalItemTracker
(
ABC
,
Generic
[
_T
]):
"""
Tracks multi-modal items in a given request and ensures that the number
of multi-modal items in a given request does not exceed the configured
maximum per prompt.
"""
def
__init__
(
self
,
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
):
super
().
__init__
()
self
.
_model_config
=
model_config
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
self
.
_consumed_items
=
{
k
:
0
for
k
in
self
.
_allowed_items
}
self
.
_items
:
List
[
_T
]
=
[]
@
staticmethod
@
lru_cache
(
maxsize
=
None
)
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
return
tokenizer
.
decode
(
token_index
)
def
_placeholder_str
(
self
,
modality
:
ModalityStr
,
current_count
:
int
)
->
Optional
[
str
]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
hf_config
=
self
.
_model_config
.
hf_config
model_type
=
hf_config
.
model_type
if
modality
==
"image"
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
f
"<|image_
{
current_count
}
|>"
if
model_type
==
"minicpmv"
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
return
"<image>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"audio"
:
if
model_type
==
"ultravox"
:
return
"<|reserved_special_token_0|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
@
staticmethod
def
_combine
(
items
:
List
[
MultiModalDataDict
])
->
MultiModalDataDict
:
mm_lists
:
Mapping
[
str
,
List
[
object
]]
=
defaultdict
(
list
)
# Merge all the multi-modal items
for
single_mm_data
in
items
:
for
mm_key
,
mm_item
in
single_mm_data
.
items
():
if
isinstance
(
mm_item
,
list
):
mm_lists
[
mm_key
].
extend
(
mm_item
)
else
:
mm_lists
[
mm_key
].
append
(
mm_item
)
# Unpack any single item lists for models that don't expect multiple.
return
{
mm_key
:
mm_list
[
0
]
if
len
(
mm_list
)
==
1
else
mm_list
for
mm_key
,
mm_list
in
mm_lists
.
items
()
}
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
)
->
Optional
[
str
]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
current_count
=
self
.
_consumed_items
.
get
(
modality
,
0
)
+
1
if
current_count
>
allowed_count
:
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
self
.
_consumed_items
[
modality
]
=
current_count
self
.
_items
.
append
(
item
)
return
self
.
_placeholder_str
(
modality
,
current_count
)
@
abstractmethod
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
raise
NotImplementedError
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
MultiModalDataDict
]):
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
return
self
.
_combine
(
self
.
_items
)
if
self
.
_items
else
None
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
MultiModalContentParser
(
self
)
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
MultiModalDataDict
]]):
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
self
.
_items
:
items
=
await
asyncio
.
gather
(
*
self
.
_items
)
return
self
.
_combine
(
items
)
return
None
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
AsyncMultiModalContentParser
(
self
)
class
BaseMultiModalContentParser
(
ABC
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# multimodal placeholder_string : count
self
.
_placeholder_counts
:
Dict
[
str
,
int
]
=
defaultdict
(
lambda
:
0
)
def
_add_placeholder
(
self
,
placeholder
:
Optional
[
str
]):
if
placeholder
:
self
.
_placeholder_counts
[
placeholder
]
+=
1
def
mm_placeholder_counts
(
self
)
->
Dict
[
str
,
int
]:
return
dict
(
self
.
_placeholder_counts
)
@
abstractmethod
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
raise
NotImplementedError
class
MultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
MultiModalItemTracker
)
->
None
:
super
().
__init__
()
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
self
.
_tracker
=
tracker
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image
=
get_and_parse_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
get_and_parse_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
class
AsyncMultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
AsyncMultiModalItemTracker
)
->
None
:
super
().
__init__
()
self
.
_tracker
=
tracker
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image_coro
=
async_get_and_parse_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
def
load_chat_template
(
...
...
@@ -112,152 +317,150 @@ def load_chat_template(
return
resolved_chat_template
@
lru_cache
(
maxsize
=
None
)
def
_mm_token_str
(
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
modality
:
Literal
[
"image"
,
"audio"
])
->
Optional
[
str
]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type
=
model_config
.
hf_config
.
model_type
if
modality
==
"image"
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
if
model_type
==
"minicpmv"
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
.
startswith
(
"llava"
):
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
return
"<image>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"audio"
:
if
model_type
==
"ultravox"
:
return
"<|reserved_special_token_0|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def
_get_full_multimodal_text_prompt
(
placeholder_
token_str
:
str
,
def
_get_full_multimodal_text_prompt
(
placeholder_
counts
:
Dict
[
str
,
int
]
,
text_prompt
:
str
)
->
str
:
"""Combine multimodal prompts for a multimodal language model"""
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
missing_placeholders
:
List
[
str
]
=
[]
for
placeholder
in
placeholder_counts
:
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts
[
placeholder
]
-=
text_prompt
.
count
(
placeholder
)
if
placeholder_counts
[
placeholder
]
<
0
:
raise
ValueError
(
f
"Found more '
{
placeholder
}
' placeholders in input prompt than "
"actual multimodal data items."
)
# NOTE: For now we assume all model architectures use the same
# placeholder + text prompt format. This may change in the future.
return
f
"
{
placeholder_token_str
}
\n
{
text_prompt
}
"
missing_placeholders
.
extend
([
placeholder
]
*
placeholder_counts
[
placeholder
])
# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return
"
\n
"
.
join
(
missing_placeholders
+
[
text_prompt
])
_TextParser
=
TypeAdapter
(
ChatCompletionContentPartTextParam
)
_ImageParser
=
TypeAdapter
(
ChatCompletionContentPartImageParam
)
_AudioParser
=
TypeAdapter
(
ChatCompletionContentPartAudioParam
)
# No need to validate using Pydantic again
_TextParser
=
partial
(
cast
,
ChatCompletionContentPartTextParam
)
_ImageParser
=
partial
(
cast
,
ChatCompletionContentPartImageParam
)
_AudioParser
=
partial
(
cast
,
ChatCompletionContentPartAudioParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
def
_parse_chat_message_content_parts
(
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
)
->
ChatMessageParseResult
:
mm_tracker
:
BaseMultiModalItemTracker
,
)
->
List
[
ConversationMessage
]:
texts
:
List
[
str
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
m
odality
:
Literal
[
"image"
,
"audio"
]
=
"image"
m
m_parser
=
mm_tracker
.
create_parser
()
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
_TextParser
.
validate_python
(
part
)[
"text"
]
text
=
_TextParser
(
part
)[
"text"
]
texts
.
append
(
text
)
elif
part_type
==
"image_url"
:
modality
=
"image"
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple multimodal inputs is currently not supported."
)
image_url
=
_ImageParser
.
validate_python
(
part
)[
"image_url"
]
image_url
=
_ImageParser
(
part
)[
"image_url"
]
if
image_url
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported and "
"will be ignored."
)
image_future
=
async_get_and_parse_image
(
image_url
[
"url"
])
mm_futures
.
append
(
image_future
)
mm_parser
.
parse_image
(
image_url
[
"url"
])
elif
part_type
==
"audio_url"
:
modality
=
"audio"
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple multimodal inputs is currently not supported."
)
audio_url
=
_AudioParser
.
validate_python
(
part
)[
"audio_url"
]
audio_future
=
async_get_and_parse_audio
(
audio_url
[
"url"
])
mm_futures
.
append
(
audio_future
)
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
mm_parser
.
parse_audio
(
audio_url
[
"url"
])
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
mm_placeholder_counts
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
text_prompt
)
if
mm_futures
:
placeholder_token_str
=
_mm_token_str
(
model_config
,
tokenizer
,
modality
)
if
placeholder_token_str
is
not
None
:
if
placeholder_token_str
in
text_prompt
:
logger
.
warning
(
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting."
)
else
:
text_prompt
=
_get_full_multimodal_text_prompt
(
placeholder_token_str
=
placeholder_token_str
,
text_prompt
=
text_prompt
,
)
return
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
# No need to validate using Pydantic again
_AssistantParser
=
partial
(
cast
,
ChatCompletionAssistantMessageParam
)
_ToolParser
=
partial
(
cast
,
ChatCompletionToolMessageParam
)
def
_parse_chat_message_content
(
message
:
ChatCompletionMessageParam
,
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
)
->
ChatMessageParseResult
:
mm_tracker
:
BaseMultiModalItemTracker
,
)
->
List
[
ConversationMessage
]:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
if
content
is
None
:
return
ChatMessageParseResult
(
messages
=
[],
mm_futures
=
[])
if
isinstance
(
content
,
str
):
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
[])
content
=
[]
elif
isinstance
(
content
,
str
):
content
=
[
ChatCompletionContentPartTextParam
(
type
=
"text"
,
text
=
content
)
]
re
turn
_parse_chat_message_content_parts
(
re
sult
=
_parse_chat_message_content_parts
(
role
,
content
,
# type: ignore
model_config
,
tokenizer
,
mm_tracker
,
)
for
result_msg
in
result
:
if
role
==
'assistant'
:
parsed_msg
=
_AssistantParser
(
message
)
if
"tool_calls"
in
parsed_msg
:
result_msg
[
"tool_calls"
]
=
list
(
parsed_msg
[
"tool_calls"
])
elif
role
==
"tool"
:
parsed_msg
=
_ToolParser
(
message
)
if
"tool_call_id"
in
parsed_msg
:
result_msg
[
"tool_call_id"
]
=
parsed_msg
[
"tool_call_id"
]
if
"name"
in
message
and
isinstance
(
message
[
"name"
],
str
):
result_msg
[
"name"
]
=
message
[
"name"
]
return
result
def
parse_chat_messages
(
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
)
->
Tuple
[
List
[
ConversationMessage
],
List
[
Awaitable
[
MultiModalDataDict
]]
]
:
)
->
Tuple
[
List
[
ConversationMessage
],
Optional
[
MultiModalDataDict
]]:
conversation
:
List
[
ConversationMessage
]
=
[]
mm_
futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
mm_
tracker
=
MultiModalItemTracker
(
model_config
,
tokenizer
)
for
msg
in
messages
:
parse_result
=
_parse_chat_message_content
(
msg
,
model_config
,
tokenizer
)
sub_messages
=
_parse_chat_message_content
(
msg
,
mm_tracker
)
conversation
.
extend
(
parse_result
.
messages
)
mm_futures
.
extend
(
parse_result
.
mm_futures
)
conversation
.
extend
(
sub_messages
)
return
conversation
,
mm_futures
return
conversation
,
mm_tracker
.
all_mm_data
()
def
parse_chat_messages_futures
(
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
)
->
Tuple
[
List
[
ConversationMessage
],
Awaitable
[
Optional
[
MultiModalDataDict
]]]:
conversation
:
List
[
ConversationMessage
]
=
[]
mm_tracker
=
AsyncMultiModalItemTracker
(
model_config
,
tokenizer
)
for
msg
in
messages
:
sub_messages
=
_parse_chat_message_content
(
msg
,
mm_tracker
)
conversation
.
extend
(
sub_messages
)
return
conversation
,
mm_tracker
.
all_mm_data
()
def
apply_chat_template
(
...
...
@@ -267,19 +470,31 @@ def apply_chat_template(
*
,
tokenize
:
bool
=
False
,
# Different from HF's default
**
kwargs
:
Any
,
)
->
str
:
)
->
Union
[
str
,
List
[
int
]]
:
if
chat_template
is
None
and
tokenizer
.
chat_template
is
None
:
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for
message
in
conversation
:
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)):
for
i
in
range
(
len
(
message
[
"tool_calls"
])):
args
:
str
=
message
[
"tool_calls"
][
i
][
"function"
][
"arguments"
]
parsed_args
:
Dict
=
json
.
loads
(
args
)
message
[
"tool_calls"
][
i
][
"function"
][
"arguments"
]
=
parsed_args
prompt
=
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
chat_template
=
chat_template
,
tokenize
=
tokenize
,
**
kwargs
,
)
assert
isinstance
(
prompt
,
str
)
return
prompt
vllm/entrypoints/llm.py
View file @
0640f227
...
...
@@ -23,7 +23,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_kwargs
from
vllm.utils
import
Counter
,
deprecate_kwargs
,
is_list_of
logger
=
init_logger
(
__name__
)
...
...
@@ -129,6 +129,7 @@ class LLM:
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
**
kwargs
,
)
->
None
:
'''
...
...
@@ -170,6 +171,7 @@ class LLM:
max_context_len_to_capture
=
max_context_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
**
kwargs
,
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
...
...
@@ -356,15 +358,18 @@ class LLM:
add_generation_prompt
:
bool
=
True
,
)
->
List
[
RequestOutput
]:
"""
Generate
s
responses for chat
messages
.
Generate responses for
a
chat
conversation
.
Converts the messages to prompts using the tokenizer and calls
the :meth:`generate` method to generate the responses.
The chat conversation is converted into a text prompt using the
tokenizer and calls the :meth:`generate` method to generate the
responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of messages to generate responses for. Each
message is a list of dictionaries with 'role' and 'content'
keys.
messages: A single conversation represented as a list of messages.
Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
...
...
@@ -385,18 +390,28 @@ class LLM:
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
conversation
s
,
_
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
prompt
s
=
apply_chat_template
(
prompt
=
apply_chat_template
(
tokenizer
,
conversation
s
,
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
)
add_generation_prompt
=
add_generation_prompt
,
)
inputs
:
PromptInputs
if
is_list_of
(
prompt
,
int
):
inputs
=
TokensPrompt
(
prompt_token_ids
=
prompt
)
else
:
inputs
=
TextPrompt
(
prompt
=
prompt
)
if
mm_data
is
not
None
:
inputs
[
"multi_modal_data"
]
=
mm_data
return
self
.
generate
(
promp
ts
,
sampling_params
,
inpu
ts
,
sampling_params
=
sampling_params
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
)
...
...
@@ -603,7 +618,6 @@ class LLM:
inputs
=
[
inputs
]
num_requests
=
len
(
inputs
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
...
...
@@ -678,6 +692,10 @@ class LLM:
postfix
=
(
f
"est. speed input:
{
0
:.
2
f
}
toks/s, "
f
"output:
{
0
:.
2
f
}
toks/s"
),
)
# In the loop below, only finished outputs are used
self
.
llm_engine
.
step_return_finished_only
=
True
# Run the engine.
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
total_in_toks
=
0
...
...
@@ -700,6 +718,10 @@ class LLM:
f
"est. speed input:
{
in_spd
:.
2
f
}
toks/s, "
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
pbar
.
update
(
1
)
# Restore original behavior
self
.
llm_engine
.
step_return_finished_only
=
False
if
use_tqdm
:
pbar
.
close
()
# Sort the outputs by request ID.
...
...
vllm/entrypoints/openai/api_server.py
View file @
0640f227
...
...
@@ -67,7 +67,7 @@ _running_tasks: Set[asyncio.Task] = set()
def
model_is_embedding
(
model_name
:
str
,
trust_remote_code
:
bool
,
quantization
:
str
)
->
bool
:
quantization
:
Optional
[
str
]
)
->
bool
:
return
ModelConfig
(
model
=
model_name
,
tokenizer
=
model_name
,
tokenizer_mode
=
"auto"
,
...
...
@@ -96,13 +96,6 @@ async def lifespan(app: FastAPI):
@
asynccontextmanager
async
def
build_async_engine_client
(
args
:
Namespace
)
->
AsyncIterator
[
Optional
[
AsyncEngineClient
]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
...
...
@@ -112,14 +105,37 @@ async def build_async_engine_client(
# Backend itself still global for the silly lil' health handler
global
async_engine_client
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
async_engine_client
=
engine
# type: ignore[assignment]
yield
engine
@
asynccontextmanager
async
def
build_async_engine_client_from_engine_args
(
engine_args
:
AsyncEngineArgs
,
disable_frontend_multiprocessing
:
bool
=
False
,
)
->
AsyncIterator
[
Optional
[
AsyncEngineClient
]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if
(
model_is_embedding
(
args
.
model
,
args
.
trust_remote_code
,
args
.
quantization
)
or
args
.
disable_frontend_multiprocessing
):
async_
engine_client
=
AsyncLLMEngine
.
from_engine_args
(
if
(
model_is_embedding
(
engine_
args
.
model
,
engine_
args
.
trust_remote_code
,
engine_
args
.
quantization
)
or
disable_frontend_multiprocessing
):
engine_client
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
yield
async_engine_client
try
:
yield
engine_client
finally
:
engine_client
.
shutdown_background_loop
()
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
...
...
@@ -148,7 +164,6 @@ async def build_async_engine_client(
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client
=
AsyncEngineRPCClient
(
rpc_path
)
async_engine_client
=
rpc_client
# type: ignore
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context
=
multiprocessing
.
get_context
(
"spawn"
)
...
...
@@ -174,7 +189,7 @@ async def build_async_engine_client(
yield
None
return
yield
async_engine_client
yield
rpc_client
# type: ignore[misc]
finally
:
# Ensure rpc server process was terminated
rpc_server_process
.
terminate
()
...
...
@@ -218,7 +233,7 @@ def mount_metrics(app: FastAPI):
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
'
^/metrics(?P<path>.*)$
'
)
metrics_route
.
path_regex
=
re
.
compile
(
"
^/metrics(?P<path>.*)$
"
)
app
.
routes
.
append
(
metrics_route
)
...
...
@@ -268,11 +283,14 @@ async def show_version():
@
router
.
post
(
"/v1/chat/completions"
)
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_chat
.
create_chat_completion
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
elif
isinstance
(
generator
,
ChatCompletionResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
...
...
@@ -407,7 +425,8 @@ async def init_app(
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
tool_parser
=
args
.
tool_call_parser
)
openai_serving_completion
=
OpenAIServingCompletion
(
async_engine_client
,
model_config
,
...
...
vllm/entrypoints/openai/cli_args.py
View file @
0640f227
...
...
@@ -163,6 +163,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help
=
"If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine."
)
parser
.
add_argument
(
"--enable-auto-tool-choice"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use"
)
parser
.
add_argument
(
"--tool-call-parser"
,
type
=
str
,
choices
=
[
"mistral"
,
"hermes"
],
default
=
None
,
help
=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice."
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
'--max-log-len'
,
...
...
vllm/entrypoints/openai/protocol.py
View file @
0640f227
...
...
@@ -5,8 +5,9 @@ from argparse import Namespace
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
from
openai.types.chat
import
ChatCompletionContentPartParam
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
,
Required
,
TypedDict
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
...
...
@@ -35,6 +36,26 @@ assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert
_LONG_INFO
.
max
==
_MOCK_LONG_INFO
.
max
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
"""Enables custom roles in the Chat Completion API."""
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Union
[
str
,
List
[
ChatCompletionContentPartParam
]]
"""The contents of the message."""
name
:
str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id
:
Optional
[
str
]
tool_calls
:
Optional
[
List
[
dict
]]
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
model_config
=
ConfigDict
(
extra
=
"forbid"
)
...
...
@@ -85,9 +106,19 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens
:
Optional
[
int
]
=
0
class
JsonSchemaResponseFormat
(
OpenAIBaseModel
):
name
:
str
description
:
Optional
[
str
]
=
None
# schema is the field in openai but that causes conflicts with pydantic so
# instead use json_schema with an alias
json_schema
:
Optional
[
Dict
[
str
,
Any
]]
=
Field
(
default
=
None
,
alias
=
'schema'
)
strict
:
Optional
[
bool
]
=
None
class
ResponseFormat
(
OpenAIBaseModel
):
# type must be "json_object" or "text"
type
:
Literal
[
"text"
,
"json_object"
]
# type must be "json_schema", "json_object" or "text"
type
:
Literal
[
"text"
,
"json_object"
,
"json_schema"
]
json_schema
:
Optional
[
JsonSchemaResponseFormat
]
=
None
class
StreamOptions
(
OpenAIBaseModel
):
...
...
@@ -135,8 +166,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature
:
Optional
[
float
]
=
0.7
top_p
:
Optional
[
float
]
=
1.0
tools
:
Optional
[
List
[
ChatCompletionToolsParam
]]
=
None
tool_choice
:
Optional
[
Union
[
Literal
[
"none"
],
tool_choice
:
Optional
[
Union
[
Literal
[
"none"
],
Literal
[
"auto"
],
ChatCompletionNamedToolChoiceParam
]]
=
"none"
# NOTE this will be ignored by VLLM -- the model determines the behavior
parallel_tool_calls
:
Optional
[
bool
]
=
False
user
:
Optional
[
str
]
=
None
# doc: begin-chat-completion-sampling-params
...
...
@@ -318,6 +352,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_guided_decoding_count
(
cls
,
data
):
if
isinstance
(
data
,
ValueError
):
raise
data
guide_count
=
sum
([
"guided_json"
in
data
and
data
[
"guided_json"
]
is
not
None
,
"guided_regex"
in
data
and
data
[
"guided_regex"
]
is
not
None
,
...
...
@@ -329,21 +366,61 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice')."
)
# you can only either use guided decoding or tools, not both
if
guide_count
>
1
and
"tool_choice"
in
data
and
data
[
"tool_choice"
]
!=
"none"
:
if
guide_count
>
1
and
data
.
get
(
"tool_choice"
,
"none"
)
not
in
(
"none"
,
"auto"
)
:
raise
ValueError
(
"You can only either use guided decoding or tools, not both."
)
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_tool_choice
(
cls
,
data
):
if
"tool_choice"
in
data
and
data
[
"tool_choice"
]
!=
"none"
:
if
not
isinstance
(
data
[
"tool_choice"
],
dict
):
raise
ValueError
(
"Currently only named tools are supported."
)
def
check_tool_usage
(
cls
,
data
):
# if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice
if
"tool_choice"
not
in
data
and
"tools"
in
data
:
data
[
"tool_choice"
]
=
"auto"
# if "tool_choice" is specified -- validation
if
"tool_choice"
in
data
:
# ensure that if "tool choice" is specified, tools are present
if
"tools"
not
in
data
or
data
[
"tools"
]
is
None
:
raise
ValueError
(
"When using `tool_choice`, `tools` must be set."
)
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if
data
[
"tool_choice"
]
!=
"auto"
and
not
isinstance
(
data
[
"tool_choice"
],
dict
):
raise
ValueError
(
"`tool_choice` must either be a named tool or
\"
auto
\"
. "
"`tool_choice=
\"
none
\"
is not supported."
)
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
if
isinstance
(
data
[
"tool_choice"
],
dict
):
valid_tool
=
False
specified_function
=
data
[
"tool_choice"
][
"function"
]
if
not
specified_function
:
raise
ValueError
(
"Incorrectly formatted `tool_choice`. Should be like "
"`{
\"
type
\"
:
\"
function
\"
,"
"
\"
function
\"
: {
\"
name
\"
:
\"
my_function
\"
}}`"
)
specified_function_name
=
specified_function
[
"name"
]
if
not
specified_function_name
:
raise
ValueError
(
"Incorrectly formatted `tool_choice`. Should be like "
"`{
\"
type
\"
:
\"
function
\"
, "
"
\"
function
\"
: {
\"
name
\"
:
\"
my_function
\"
}}`"
)
for
tool
in
data
[
"tools"
]:
if
tool
[
"function"
][
"name"
]
==
specified_function_name
:
valid_tool
=
True
break
if
not
valid_tool
:
raise
ValueError
(
"The tool specified in `tool_choice` does not match any"
" of the specified `tools`"
)
return
data
...
...
@@ -403,7 +480,7 @@ class CompletionRequest(OpenAIBaseModel):
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the JSON schema."
)
,
description
=
"If specified, the output will follow the JSON schema."
,
)
guided_regex
:
Optional
[
str
]
=
Field
(
default
=
None
,
...
...
@@ -623,9 +700,41 @@ class ToolCall(OpenAIBaseModel):
function
:
FunctionCall
class
DeltaFunctionCall
(
BaseModel
):
name
:
Optional
[
str
]
=
None
arguments
:
Optional
[
str
]
=
None
# a tool call delta where everything is optional
class
DeltaToolCall
(
OpenAIBaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-tool-
{
random_uuid
()
}
"
)
type
:
Literal
[
"function"
]
=
"function"
index
:
int
function
:
Optional
[
DeltaFunctionCall
]
=
None
# the initial delta that gets sent once a new tool call is started;
class
InitialDeltaToolCall
(
DeltaToolCall
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-tool-
{
random_uuid
()
}
"
)
type
:
Literal
[
"function"
]
=
"function"
index
:
int
class
ExtractedToolCallInformation
(
BaseModel
):
# indicate if tools were called
tools_called
:
bool
# extracted tool calls
tool_calls
:
List
[
ToolCall
]
# content - per OpenAI spec, content AND tool calls can be returned rarely
# But some models will do this intentionally
content
:
Optional
[
str
]
=
None
class
ChatMessage
(
OpenAIBaseModel
):
role
:
str
content
:
str
content
:
Optional
[
str
]
=
None
tool_calls
:
List
[
ToolCall
]
=
Field
(
default_factory
=
list
)
...
...
@@ -647,7 +756,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
index
:
int
message
:
ChatMessage
logprobs
:
Optional
[
ChatCompletionLogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
# per OpenAI spec this is the default
finish_reason
:
Optional
[
str
]
=
"stop"
# not part of the OpenAI spec but included in vLLM for legacy reasons
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
...
...
@@ -664,7 +775,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class
DeltaMessage
(
OpenAIBaseModel
):
role
:
Optional
[
str
]
=
None
content
:
Optional
[
str
]
=
None
tool_calls
:
List
[
ToolCall
]
=
Field
(
default_factory
=
list
)
tool_calls
:
List
[
Delta
ToolCall
]
=
Field
(
default_factory
=
list
)
class
ChatCompletionResponseStreamChoice
(
OpenAIBaseModel
):
...
...
vllm/entrypoints/openai/rpc/client.py
View file @
0640f227
import
asyncio
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
Any
,
AsyncGenerator
,
Mapping
,
Optional
from
typing
import
Any
,
AsyncGenerator
,
Iterator
,
Mapping
,
Optional
from
uuid
import
uuid4
import
cloudpickle
import
zmq
import
zmq.asyncio
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
...
...
@@ -101,6 +104,7 @@ class AsyncEngineRPCClient:
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit
=
self
.
context
.
get
(
zmq
.
constants
.
SOCKET_LIMIT
)
assert
isinstance
(
socket_limit
,
int
)
if
socket_limit
<
VLLM_RPC_SOCKET_LIMIT_CUTOFF
:
raise
ValueError
(
f
"Found zmq.constants.SOCKET_LIMIT=
{
socket_limit
}
, which caps "
...
...
@@ -114,18 +118,21 @@ class AsyncEngineRPCClient:
self
.
context
.
set
(
zmq
.
constants
.
MAX_SOCKETS
,
socket_limit
)
# IPC connection to RPC Server (uses unix sockets).
self
.
to_rpc_server
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
to_rpc_server
.
bind
(
rpc_path
)
# In process proxy to RPC Server (uses memory-based messaging).
self
.
from_api_server
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
from_api_server
.
bind
(
INPROC_PROXY_PATH
)
# Asyncio background task for the proxy.
self
.
proxy_task
=
asyncio
.
create_task
(
self
.
proxy_
in_
task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
from_api_server
,
self
.
to_rpc_server
))
self
.
proxy_out_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
to_rpc_server
,
self
.
from_api_server
))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
...
...
@@ -135,20 +142,11 @@ class AsyncEngineRPCClient:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self
.
limit_concurrency
=
socket_limit
//
2
-
2
async
def
run_proxy
(
self
,
socket_from
,
socket_to
):
async
def
run_proxy
(
self
,
socket_from
:
Socket
,
socket_to
:
Socket
):
"""Background task that runs a proxy"""
poller
=
zmq
.
asyncio
.
Poller
()
poller
.
register
(
socket_from
,
zmq
.
constants
.
POLLIN
)
poller
.
register
(
socket_to
,
zmq
.
constants
.
POLLIN
)
while
True
:
events
=
await
poller
.
poll
()
events
=
dict
(
events
)
if
socket_from
in
events
:
identity
,
msg
=
await
socket_from
.
recv_multipart
()
await
socket_to
.
send_multipart
([
identity
,
msg
])
if
socket_to
in
events
:
identity
,
msg
=
await
socket_to
.
recv_multipart
()
await
socket_from
.
send_multipart
([
identity
,
msg
])
frames
=
await
socket_from
.
recv_multipart
(
copy
=
False
)
await
socket_to
.
send_multipart
(
frames
,
copy
=
False
)
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
...
...
@@ -179,7 +177,7 @@ class AsyncEngineRPCClient:
self
.
context
.
destroy
()
@
contextmanager
def
to_proxy_socket
(
self
):
def
to_proxy_socket
(
self
)
->
Iterator
[
Socket
]
:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
...
...
@@ -207,7 +205,8 @@ class AsyncEngineRPCClient:
with
self
.
to_proxy_socket
()
as
socket
:
# Ping RPCServer with a request.
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
request
)])
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
request
),
),
copy
=
False
)
# Make sure the server responds
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
...
...
@@ -215,7 +214,9 @@ class AsyncEngineRPCClient:
f
"
{
self
.
_data_timeout
}
ms"
)
# Await the data from the Server.
data
=
cloudpickle
.
loads
(
await
socket
.
recv
())
frame
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
frame
,
Frame
)
data
=
pickle
.
loads
(
frame
.
buffer
)
if
isinstance
(
data
,
Exception
):
# Re-raise exceptions returned by the server
...
...
@@ -233,23 +234,23 @@ class AsyncEngineRPCClient:
return
data
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
):
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
socket
:
Optional
[
Socket
]
=
None
):
"""Send one-way RPC request to trigger an action."""
async
def
do_rpc_call
(
socket
:
zmq
.
asyncio
.
Socket
,
request
:
RPC_REQUEST_TYPE
):
async
def
do_rpc_call
(
socket
:
Socket
,
request
:
RPC_REQUEST_TYPE
):
await
socket
.
send_multipart
(
[
cloudpickle
.
dumps
(
request
)
]
)
await
socket
.
send_multipart
(
(
cloudpickle
.
dumps
(
request
)
,
)
)
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
return
cloudpickle
.
loads
(
await
socket
.
recv
())
frame
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
frame
,
Frame
)
return
pickle
.
loads
(
frame
.
buffer
)
# Make a new socket connection.
if
socket
is
None
:
...
...
@@ -385,21 +386,20 @@ class AsyncEngineRPCClient:
try
:
with
self
.
to_proxy_socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
RPCGenerateRequest
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
))
])
await
socket
.
send_multipart
((
cloudpickle
.
dumps
(
RPCGenerateRequest
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)),
))
# Stream back the results from the RPC Server.
while
not
finished
:
message
=
await
socket
.
recv
()
request_output
=
cloudpickle
.
loads
(
message
)
message
=
await
socket
.
recv
(
copy
=
False
)
assert
isinstance
(
message
,
Frame
)
request_output
=
pickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request_output
,
Exception
):
# On exception, check if the server is still healthy
...
...
@@ -423,9 +423,7 @@ class AsyncEngineRPCClient:
if
not
finished
and
not
self
.
_errored
:
await
self
.
abort
(
request_id
)
async
def
check_health
(
self
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
)
->
None
:
async
def
check_health
(
self
,
socket
:
Optional
[
Socket
]
=
None
)
->
None
:
"""Raise if unhealthy"""
await
self
.
_send_one_way_rpc_request
(
...
...
@@ -450,4 +448,4 @@ class AsyncEngineRPCClient:
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
STOP_PROFILE
,
error_message
=
"RPCRequest STOP_PROFILE failed."
)
\ No newline at end of file
error_message
=
"RPCRequest STOP_PROFILE failed."
)
vllm/entrypoints/openai/rpc/server.py
View file @
0640f227
import
asyncio
import
pickle
import
signal
from
typing
import
Any
,
Coroutine
,
Union
...
...
@@ -7,6 +8,8 @@ import uvloop
import
zmq
import
zmq.asyncio
from
typing_extensions
import
Never
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
...
...
@@ -35,7 +38,7 @@ class AsyncEngineRPCServer:
self
.
context
=
zmq
.
asyncio
.
Context
()
# Init socket.
self
.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
socket
.
connect
(
rpc_path
)
...
...
@@ -63,30 +66,31 @@ class AsyncEngineRPCServer:
else
:
raise
ValueError
(
"Unknown Config Request: %s"
,
request
)
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
config
)]
)
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
config
)),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
is_tracing_enabled
(
self
,
identity
):
"""Send the is_tracing_enabled flag"""
tracing_flag
=
await
self
.
engine
.
is_tracing_enabled
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
tracing_flag
)
]
)
(
identity
,
pickle
.
dumps
(
tracing_flag
)
)
)
async
def
do_log_stats
(
self
,
identity
):
"""Log stats and confirm success."""
await
self
.
engine
.
do_log_stats
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
]
)
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
)
)
async
def
is_server_ready
(
self
,
identity
):
"""Notify the client that we are ready."""
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
]
)
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
)
)
async
def
abort
(
self
,
identity
,
request
:
RPCAbortRequest
):
"""Abort request and notify the client of success."""
...
...
@@ -96,7 +100,7 @@ class AsyncEngineRPCServer:
result
:
Union
[
str
,
Exception
]
=
VLLM_RPC_SUCCESS_STR
except
Exception
as
e
:
result
=
e
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
result
)
]
)
await
self
.
socket
.
send_multipart
(
(
identity
,
pickle
.
dumps
(
result
)
)
)
async
def
generate
(
self
,
identity
,
generate_request
:
RPCGenerateRequest
):
try
:
...
...
@@ -110,45 +114,47 @@ class AsyncEngineRPCServer:
async
for
request_output
in
results_generator
:
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
request_output
)
]
)
(
identity
,
pickle
.
dumps
(
request_output
)
),
copy
=
False
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
check_health
(
self
,
identity
):
try
:
await
self
.
engine
.
check_health
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
]
)
(
identity
,
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)
)
)
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
await
self
.
socket
.
send_multipart
((
identity
,
pickle
.
dumps
(
e
)),
copy
=
False
)
async
def
start_profile
(
self
,
identity
):
logger
.
info
(
"Starting profiler..."
)
await
self
.
engine
.
start_profile
()
logger
.
info
(
"Profiler started."
)
await
self
.
socket
.
send_multipart
(
[
await
self
.
socket
.
send_multipart
(
(
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
]
)
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
)
)
async
def
stop_profile
(
self
,
identity
):
logger
.
info
(
"Stopping profiler..."
)
await
self
.
engine
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
await
self
.
socket
.
send_multipart
(
[
await
self
.
socket
.
send_multipart
(
(
identity
,
cloud
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
]
)
pickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
)
)
def
_make_handler_coro
(
self
,
identity
,
message
)
->
Coroutine
[
Any
,
Any
,
Never
]:
message
:
Frame
)
->
Coroutine
[
Any
,
Any
,
Never
]:
"""Route the zmq message to the handler coroutine."""
request
=
cloudpickle
.
loads
(
message
)
request
=
cloudpickle
.
loads
(
message
.
buffer
)
if
isinstance
(
request
,
RPCGenerateRequest
):
return
self
.
generate
(
identity
,
request
)
...
...
@@ -189,7 +195,7 @@ class AsyncEngineRPCServer:
running_tasks
=
set
()
while
True
:
# Wait for a request.
identity
,
message
=
await
self
.
socket
.
recv_multipart
()
identity
,
message
=
await
self
.
socket
.
recv_multipart
(
copy
=
False
)
# Process the request async.
task
=
asyncio
.
create_task
(
...
...
vllm/entrypoints/openai/run_batch.py
View file @
0640f227
...
...
@@ -3,10 +3,11 @@ from io import StringIO
from
typing
import
Awaitable
,
Callable
,
List
import
aiohttp
from
prometheus_client
import
start_http_server
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
,
logger
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
BatchRequestInput
,
BatchRequestOutput
,
...
...
@@ -16,13 +17,10 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
def
parse_args
():
parser
=
FlexibleArgumentParser
(
...
...
@@ -59,6 +57,24 @@ def parse_args():
'ID numbers being printed in log.'
'
\n\n
Default: Unlimited'
)
parser
.
add_argument
(
"--enable-metrics"
,
action
=
"store_true"
,
help
=
"Enable Prometheus metrics"
)
parser
.
add_argument
(
"--url"
,
type
=
str
,
default
=
"0.0.0.0"
,
help
=
"URL to the Prometheus metrics server "
"(only needed if enable-metrics is set)."
,
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set)."
,
)
return
parser
.
parse_args
()
...
...
@@ -184,7 +200,15 @@ async def main(args):
if
__name__
==
"__main__"
:
args
=
parse_args
()
logger
.
info
(
"vLLM
API server
version %s"
,
VLLM_VERSION
)
logger
.
info
(
"vLLM
batch processing API
version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if
args
.
enable_metrics
:
logger
.
info
(
"Prometheus metrics enabled"
)
start_http_server
(
port
=
args
.
port
,
addr
=
args
.
url
)
else
:
logger
.
info
(
"Prometheus metrics disabled"
)
asyncio
.
run
(
main
(
args
))
vllm/entrypoints/openai/serving_chat.py
View file @
0640f227
import
asyncio
import
json
import
time
from
typing
import
AsyncGenerator
,
AsyncIterator
,
Dict
,
Final
,
List
,
Optional
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
Final
,
List
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
...
...
@@ -11,22 +13,25 @@ from vllm.engine.protocol import AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_chat_template
,
load_chat_template
,
parse_chat_messages
)
parse_chat_messages
_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionNamedToolChoiceParam
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
Delta
Message
,
ErrorRespons
e
,
FunctionCall
,
ToolCall
,
UsageInfo
)
ChatCompletionStreamResponse
,
ChatMessage
,
Delta
FunctionCall
,
DeltaMessag
e
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
,
PromptAdapterPath
)
PromptAdapterPath
,
TextTokensPrompt
)
from
vllm.entrypoints.openai.tool_parsers
import
(
Hermes2ProToolParser
,
MistralToolParser
,
ToolParser
)
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
...
...
@@ -38,19 +43,19 @@ logger = init_logger(__name__)
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
async_engine_client
:
AsyncEngineClient
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
]
,
response_role
:
str
,
*
,
lora_module
s
:
Optional
[
List
[
LoRAModule
Path
]],
prompt_adapt
er
s
:
Optional
[
List
[
PromptAdapterPath
]
],
request_logger
:
Optional
[
RequestLogge
r
],
chat_template
:
Optional
[
str
]
,
return_tokens_as_token_id
s
:
bool
=
False
,
):
def
__init__
(
self
,
async_engine_client
:
AsyncEngineClient
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
]
,
response_role
:
str
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]
,
prompt_adapter
s
:
Optional
[
List
[
PromptAdapter
Path
]],
request_logg
er
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
st
r
],
return_tokens_as_token_ids
:
bool
=
False
,
enable_auto_tool
s
:
bool
=
False
,
tool_parser
:
Optional
[
str
]
=
None
):
super
().
__init__
(
async_engine_client
=
async_engine_client
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
...
...
@@ -60,10 +65,27 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids
=
return_tokens_as_token_ids
)
self
.
response_role
=
response_role
# If this is None we use the tokenizer's default chat template
self
.
use_tool_use_model_template
=
False
self
.
chat_template
=
load_chat_template
(
chat_template
)
# set up tool use
self
.
enable_auto_tools
:
bool
=
enable_auto_tools
if
self
.
enable_auto_tools
:
logger
.
info
(
"
\"
auto
\"
tool choice has been enabled please note that while"
" the parallel_tool_calls client option is preset for "
"compatibility reasons, it will be ignored."
)
self
.
tool_parser
:
Optional
[
Callable
[[
AnyTokenizer
],
ToolParser
]]
=
None
if
self
.
enable_auto_tools
:
if
tool_parser
==
"mistral"
:
self
.
tool_parser
=
MistralToolParser
elif
tool_parser
==
"hermes"
:
self
.
tool_parser
=
Hermes2ProToolParser
else
:
raise
TypeError
(
"Error: --enable-auto-tool-choice requires "
"--tool-call-parser"
)
async
def
create_chat_completion
(
self
,
request
:
ChatCompletionRequest
,
...
...
@@ -76,11 +98,10 @@ class OpenAIServingChat(OpenAIServing):
for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
"""
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
logger
.
error
(
"Error with model %s"
,
error_check_ret
)
return
error_check_ret
try
:
...
...
@@ -93,7 +114,7 @@ class OpenAIServingChat(OpenAIServing):
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
conversation
,
mm_future
s
=
parse_chat_messages
(
conversation
,
mm_
data_
future
=
parse_chat_messages
_futures
(
request
.
messages
,
model_config
,
tokenizer
)
tool_dicts
=
None
if
request
.
tools
is
None
else
[
...
...
@@ -113,30 +134,47 @@ class OpenAIServingChat(OpenAIServing):
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
mm_data
:
Optional
[
MultiModalDataDict
]
=
None
try
:
if
len
(
mm_futures
):
# since we support only single mm data currently
assert
len
(
mm_futures
)
==
1
,
"Multiple 'image_url' input is currently not supported."
mm_data
=
await
mm_futures
[
0
]
mm_data
=
await
mm_data_future
except
Exception
as
e
:
logger
.
error
(
"Error in loading multi-modal data: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
# validation for OpenAI tools
# tool_choice = "required" is not supported
if
request
.
tool_choice
==
"required"
:
return
self
.
create_error_response
(
"tool_choice =
\"
required
\"
is not supported!"
)
# "auto" tools requires --enable-auto-tool-choice
# and --tool-call-parser
if
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
self
.
tool_parser
is
not
None
):
return
self
.
create_error_response
(
"
\"
auto
\"
tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
request_id
=
f
"chat-
{
random_uuid
()
}
"
try
:
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
prompt
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
if
isinstance
(
prompt
,
str
):
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
prompt
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
else
:
assert
isinstance
(
prompt
,
list
)
and
isinstance
(
prompt
[
0
],
int
),
"Prompt has to be either a string or a list of token ids"
prompt_inputs
=
TextTokensPrompt
(
prompt
=
tokenizer
.
decode
(
prompt
),
prompt_token_ids
=
prompt
)
assert
prompt_inputs
is
not
None
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
...
...
@@ -184,6 +222,7 @@ class OpenAIServingChat(OpenAIServing):
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
try
:
return
await
self
.
chat_completion_full_generator
(
request
,
result_generator
,
request_id
,
conversation
,
tokenizer
)
...
...
@@ -216,6 +255,9 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens
=
[
0
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
tool_parser
:
Optional
[
ToolParser
]
=
self
.
tool_parser
(
tokenizer
)
if
self
.
tool_parser
else
None
try
:
async
for
res
in
result_generator
:
# We need to do it here, because if there are exceptions in
...
...
@@ -225,6 +267,9 @@ class OpenAIServingChat(OpenAIServing):
# Send first response for each request.n (index) with
# the role
role
=
self
.
get_chat_request_role
(
request
)
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for
i
in
range
(
num_choices
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
...
...
@@ -237,14 +282,18 @@ class OpenAIServingChat(OpenAIServing):
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
# if usage should be included
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
# if continuous usage stats are requested, add it
if
request
.
stream_options
.
continuous_usage_stats
:
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
chunk
.
usage
=
usage
# otherwise don't
else
:
chunk
.
usage
=
None
...
...
@@ -254,7 +303,7 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# last message
if
request
.
echo
:
last_msg_content
=
""
last_msg_content
:
Optional
[
str
]
=
""
if
conversation
and
conversation
[
-
1
].
get
(
"content"
)
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
...
...
@@ -295,6 +344,7 @@ class OpenAIServingChat(OpenAIServing):
first_iteration
=
False
for
output
in
res
.
outputs
:
i
=
output
.
index
if
finish_reason_sent
[
i
]:
...
...
@@ -317,20 +367,50 @@ class OpenAIServingChat(OpenAIServing):
logprobs
=
None
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
delta_message
:
Optional
[
DeltaMessage
]
=
None
if
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
:
# handle streaming deltas for tools with named
tool_choice
if
(
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
)
:
delta_message
=
DeltaMessage
(
tool_calls
=
[
ToolCall
(
function
=
FunctionCall
(
Delta
ToolCall
(
function
=
Delta
FunctionCall
(
name
=
request
.
tool_choice
.
function
.
name
,
arguments
=
delta_text
))
arguments
=
delta_text
),
index
=
i
)
])
# handle streaming deltas for tools with "auto" tool choice
elif
(
self
.
_should_stream_with_auto_tool_parsing
(
request
)
and
tool_parser
):
delta_message
=
(
tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_texts
[
i
],
current_text
=
output
.
text
,
delta_text
=
delta_text
,
previous_token_ids
=
\
output
.
token_ids
[
:
-
1
*
len
(
delta_token_ids
)
],
current_token_ids
=
output
.
token_ids
,
delta_token_ids
=
delta_token_ids
)
)
# handle streaming just a content delta
else
:
delta_message
=
DeltaMessage
(
content
=
delta_text
)
# set the previous values for the next iteration
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
# wasn't ready to send a token, then
# get the next token without streaming a chunk
if
delta_message
is
None
:
continue
if
output
.
finish_reason
is
None
:
# Send token-by-token response for each request.n
...
...
@@ -345,6 +425,8 @@ class OpenAIServingChat(OpenAIServing):
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
# handle usage stats if requested & if continuous
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
...
...
@@ -362,14 +444,55 @@ class OpenAIServingChat(OpenAIServing):
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# if the model is finished generating
else
:
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using guided decoding
if
tool_parser
:
index
=
len
(
tool_parser
.
prev_tool_call_arr
)
-
1
if
len
(
tool_parser
.
prev_tool_call_arr
)
>
0
else
0
else
:
index
=
0
if
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
)
and
tool_parser
:
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call
=
json
.
dumps
(
tool_parser
.
prev_tool_call_arr
[
index
].
get
(
"arguments"
,
{}))
# get what we've streamed so for for arguments
# for the current tool
actual_call
=
tool_parser
.
streamed_args_for_tool
[
index
]
# check to see if there's anything left to stream
remaining_call
=
expected_call
.
replace
(
actual_call
,
""
,
1
)
# set that as a delta message
delta_message
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
index
,
function
=
DeltaFunctionCall
(
arguments
=
remaining_call
).
model_dump
(
exclude_none
=
True
))
])
# Send the finish response for each request.n only once
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
delta_message
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
finish_reason
=
output
.
finish_reason
if
not
(
tool_parser
and
len
(
tool_parser
.
prev_tool_call_arr
))
else
"tool_calls"
,
stop_reason
=
output
.
stop_reason
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
...
...
@@ -395,6 +518,8 @@ class OpenAIServingChat(OpenAIServing):
yield
f
"data:
{
data
}
\n\n
"
finish_reason_sent
[
i
]
=
True
# once the final token is handled, if stream_options.include_usage
# is sent, send the usage
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
final_usage
=
UsageInfo
(
...
...
@@ -416,6 +541,7 @@ class OpenAIServingChat(OpenAIServing):
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
logger
.
error
(
"error in chat completion stream generator: %s"
,
e
)
data
=
self
.
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
data
}
\n\n
"
# Send the final done message after all response.n are finished
...
...
@@ -460,8 +586,21 @@ class OpenAIServingChat(OpenAIServing):
else
:
logprobs
=
None
if
request
.
tool_choice
and
type
(
# by default, tools are not used.
tools_called
=
False
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if
not
(
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
# if the request uses tools and specified a tool choice
elif
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
:
message
=
ChatMessage
(
role
=
role
,
content
=
""
,
...
...
@@ -470,14 +609,47 @@ class OpenAIServingChat(OpenAIServing):
name
=
request
.
tool_choice
.
function
.
name
,
arguments
=
output
.
text
))
])
tools_called
=
True
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif
not
request
.
tool_choice
or
request
.
tool_choice
==
"none"
:
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
# handle when there are tools and tool choice is auto
elif
request
.
tools
and
(
request
.
tool_choice
==
"auto"
or
request
.
tool_choice
is
None
)
and
self
.
enable_auto_tools
\
and
self
.
tool_parser
:
tool_parser
=
self
.
tool_parser
(
tokenizer
)
tool_call_info
=
tool_parser
.
extract_tool_calls
(
output
.
text
)
tools_called
=
tool_call_info
.
tools_called
if
tool_call_info
.
tools_called
:
message
=
ChatMessage
(
role
=
role
,
content
=
tool_call_info
.
content
,
tool_calls
=
tool_call_info
.
tool_calls
)
else
:
# FOR NOW make it a chat message; we will have to detect
# the type to make it later.
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
# undetermined case that is still important to handle
else
:
logger
.
error
(
"Error in chat_completion_full_generator - cannot determine"
" if tools should be extracted. Returning a standard chat "
"completion."
)
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
)
choice_data
=
ChatCompletionResponseChoice
(
index
=
output
.
index
,
message
=
message
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
finish_reason
=
"tool_calls"
if
tools_called
else
output
.
finish_reason
if
output
.
finish_reason
else
"stop"
,
stop_reason
=
output
.
stop_reason
)
choices
.
append
(
choice_data
)
...
...
@@ -485,10 +657,11 @@ class OpenAIServingChat(OpenAIServing):
last_msg_content
=
""
if
conversation
and
conversation
[
-
1
].
get
(
"content"
)
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
for
choice
in
choices
:
full_message
=
last_msg_content
+
choice
.
message
.
content
full_message
=
last_msg_content
+
(
choice
.
message
.
content
or
""
)
choice
.
message
.
content
=
full_message
num_prompt_tokens
=
len
(
final_res
.
prompt_token_ids
)
...
...
@@ -571,3 +744,38 @@ class OpenAIServingChat(OpenAIServing):
))
return
ChatCompletionLogProbs
(
content
=
logprobs_content
)
def
_should_stream_with_auto_tool_parsing
(
self
,
request
:
ChatCompletionRequest
):
"""
Utility function to check if streamed tokens should go through the tool
call parser that was configured.
We only want to do this IF user-provided tools are set, a tool parser
is configured, "auto" tool choice is enabled, and the request's tool
choice field indicates that "auto" tool choice should be used.
"""
return
(
request
.
tools
and
self
.
tool_parser
and
self
.
enable_auto_tools
and
request
.
tool_choice
in
[
'auto'
,
None
])
def
_should_check_for_unstreamed_tool_arg_tokens
(
self
,
delta_message
:
Optional
[
DeltaMessage
],
output
:
CompletionOutput
,
)
->
bool
:
"""
Check to see if we should check for unstreamed tool arguments tokens.
This is only applicable when auto tool parsing is enabled, the delta
is a tool call with arguments.
"""
# yapf: disable
return
bool
(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self
.
enable_auto_tools
and
self
.
tool_parser
and
delta_message
and
delta_message
.
tool_calls
and
delta_message
.
tool_calls
[
0
]
and
delta_message
.
tool_calls
[
0
].
function
and
delta_message
.
tool_calls
[
0
].
function
.
arguments
is
not
None
and
output
.
finish_reason
is
not
None
)
Prev
1
…
4
5
6
7
8
9
10
11
12
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment