Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
685 additions
and
781 deletions
+685
-781
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+5
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+70
-91
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+7
-31
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+194
-159
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+30
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+172
-218
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+6
-10
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+25
-65
vllm/engine/metrics.py
vllm/engine/metrics.py
+0
-66
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+1
-11
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+0
-4
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+1
-8
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+6
-8
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+0
-5
vllm/engine/protocol.py
vllm/engine/protocol.py
+6
-2
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+146
-46
vllm/entrypoints/cli/main.py
vllm/entrypoints/cli/main.py
+0
-11
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+6
-3
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+10
-37
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
711aa9d5
...
...
@@ -47,7 +47,10 @@ class MultiConnector(KVConnectorBase_V1):
assert
ktcs
is
not
None
for
ktc
in
ktcs
:
temp_config
=
copy
.
copy
(
vllm_config
)
temp_config
.
kv_transfer_config
=
KVTransferConfig
(
**
ktc
)
engine_id
=
ktc
.
get
(
"engine_id"
,
vllm_config
.
kv_transfer_config
.
engine_id
)
temp_config
.
kv_transfer_config
=
KVTransferConfig
(
**
ktc
,
engine_id
=
engine_id
)
self
.
_connectors
.
append
(
KVConnectorFactory
.
create_connector_v1
(
temp_config
,
role
))
...
...
@@ -187,7 +190,7 @@ class MultiConnector(KVConnectorBase_V1):
async_saves
+=
1
if
txfer_params
is
not
None
:
if
kv_txfer_params
is
not
None
:
#TODO we can probably change this to merge the dicts here,
#
TODO we can probably change this to merge the dicts here,
# checking for key clashes.
raise
RuntimeError
(
"Only one connector can produce KV transfer params"
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
711aa9d5
...
...
@@ -79,7 +79,8 @@ class ReqMeta:
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
def
__init__
(
self
):
self
.
requests
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
reqs_to_recv
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
def
add_new_req
(
self
,
...
...
@@ -87,7 +88,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
self
.
req
uests
[
request_id
]
=
ReqMeta
(
self
.
req
s_to_recv
[
request_id
]
=
ReqMeta
(
local_block_ids
=
local_block_ids
,
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
...
...
@@ -194,10 +195,12 @@ class NixlConnectorScheduler:
vllm_config
.
parallel_config
.
tensor_parallel_size
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
# Requests that need to start recv.
# Requests that need to start recv
/send
.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]]]
=
{}
# Reqs to send and their expiration time
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
...
...
@@ -284,6 +287,9 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers
self
.
_reqs_need_recv
.
clear
()
meta
.
reqs_to_send
=
self
.
_reqs_need_send
self
.
_reqs_need_send
=
{}
return
meta
def
request_finished
(
...
...
@@ -325,6 +331,11 @@ class NixlConnectorScheduler:
# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks
=
len
(
computed_block_ids
)
>
0
if
delay_free_blocks
:
# Prefill request on remote. It will be read from D upon completion
self
.
_reqs_need_send
[
request
.
request_id
]
=
time
.
perf_counter
(
)
+
envs
.
VLLM_NIXL_ABORT_REQUEST_TIMEOUT
return
delay_free_blocks
,
dict
(
do_remote_prefill
=
True
,
do_remote_decode
=
False
,
...
...
@@ -394,14 +405,8 @@ class NixlConnectorWorker:
# In progress transfers.
# [req_id -> list[handle]]
self
.
_recving_transfers
=
defaultdict
[
ReqId
,
list
[
Transfer
]](
list
)
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self
.
_done_recving_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
self
.
_done_sending_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
# Track the expiration time of requests that are waiting to be sent.
self
.
_reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
# Background thread for handling new handshake requests.
self
.
_nixl_handshake_listener_t
:
Optional
[
threading
.
Thread
]
=
None
...
...
@@ -475,8 +480,13 @@ class NixlConnectorWorker:
"Connection listener got unexpected message %s"
,
msg
)
sock
.
send_multipart
((
identity
,
b
""
,
encoded_data
))
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
)
->
dict
[
int
,
str
]:
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
,
expected_engine_id
:
str
,
)
->
dict
[
int
,
str
]:
"""Do a NIXL handshake with a remote instance."""
start_time
=
time
.
perf_counter
()
...
...
@@ -485,26 +495,6 @@ class NixlConnectorWorker:
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
def
handshake
(
path
:
str
,
rank
:
int
)
->
str
:
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
sock
.
send
(
GET_META_MSG
)
metadata_bytes
=
sock
.
recv
()
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
metadata
=
decoder
.
decode
(
metadata_bytes
)
got_metadata_time
=
time
.
perf_counter
()
# Register Remote agent.
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
rank
,
remote_tp_size
)
setup_agent_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
)
return
remote_agent_name
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
remote_tp_size
...
...
@@ -512,8 +502,32 @@ class NixlConnectorWorker:
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
p_remote_rank
)
logger
.
debug
(
"Querying metadata on path: %s at remote rank %s"
,
path
,
p_remote_rank
)
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
sock
.
send
(
GET_META_MSG
)
metadata_bytes
=
sock
.
recv
()
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
metadata
=
decoder
.
decode
(
metadata_bytes
)
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
# Register Remote agent.
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
p_remote_rank
,
remote_tp_size
)
setup_agent_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
)
# Remote rank -> agent name.
return
{
p_remote_rank
:
handshake
(
path
,
p_remote_rank
)
}
return
{
p_remote_rank
:
remote_agent_name
}
def
_background_nixl_handshake
(
self
,
req_id
:
str
,
remote_engine_id
:
EngineId
,
meta
:
ReqMeta
):
...
...
@@ -522,7 +536,7 @@ class NixlConnectorWorker:
if
fut
is
None
:
fut
=
self
.
_handshake_initiation_executor
.
submit
(
self
.
_nixl_handshake
,
meta
.
remote_host
,
meta
.
remote_port
,
meta
.
tp_size
)
meta
.
tp_size
,
remote_engine_id
)
self
.
_handshake_futures
[
remote_engine_id
]
=
fut
def
done_callback
(
f
:
Future
[
dict
[
int
,
str
]],
eid
=
remote_engine_id
):
...
...
@@ -725,10 +739,10 @@ class NixlConnectorWorker:
if
remote_tp_rank
in
self
.
_remote_agents
.
get
(
engine_id
,
{}):
return
self
.
_remote_agents
[
engine_id
][
remote_tp_rank
]
if
engine_id
in
self
.
_tp_size
:
assert
self
.
_tp_size
[
engine_id
]
==
remote_tp_size
else
:
if
engine_id
not
in
self
.
_tp_size
:
self
.
_tp_size
[
engine_id
]
=
remote_tp_size
else
:
assert
self
.
_tp_size
[
engine_id
]
==
remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
...
...
@@ -808,15 +822,9 @@ class NixlConnectorWorker:
def
get_finished
(
self
)
->
tuple
[
set
[
str
],
set
[
str
]]:
"""
Get requests that are done sending or recving.
In TP>1 setup, each rank exchanges KVs with its counterpart
ranks independently. get_finished() runs in a worker creates
the done_sending and done_recving sets that are sent to the
scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
are done before adding to finished, Ranks 1 to N-1 communicate
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
done_sending
=
self
.
_get_new_notifs
()
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
...
...
@@ -826,50 +834,17 @@ class NixlConnectorWorker:
"and %s requests done recving"
,
self
.
tp_rank
,
len
(
done_sending
),
len
(
done_recving
))
if
self
.
world_size
==
1
:
return
done_sending
,
done_recving
# Rank 0: get finished from all other ranks.
if
self
.
tp_rank
==
0
:
for
req_id
in
done_sending
:
self
.
_done_sending_count
[
req_id
]
+=
1
for
req_id
in
done_recving
:
self
.
_done_recving_count
[
req_id
]
+=
1
# Keep track of how many other ranks have finished.
other_ranks_finished_ids
:
list
[
str
]
=
[]
for
i
in
range
(
1
,
self
.
world_size
):
other_ranks_finished_ids
.
extend
(
self
.
tp_group
.
recv_object
(
src
=
i
))
for
req_id
in
other_ranks_finished_ids
:
if
(
req_id
in
self
.
_done_recving_count
or
req_id
in
self
.
_recving_transfers
):
self
.
_done_recving_count
[
req_id
]
+=
1
else
:
self
.
_done_sending_count
[
req_id
]
+=
1
# Return ids that finished on all ranks to the scheduler.
all_done_recving
:
set
[
str
]
=
set
()
for
req_id
in
list
(
self
.
_done_recving_count
.
keys
()):
if
self
.
_done_recving_count
[
req_id
]
==
self
.
world_size
:
del
self
.
_done_recving_count
[
req_id
]
all_done_recving
.
add
(
req_id
)
all_done_sending
:
set
[
str
]
=
set
()
for
req_id
in
list
(
self
.
_done_sending_count
.
keys
()):
if
self
.
_done_sending_count
[
req_id
]
==
self
.
world_size
:
del
self
.
_done_sending_count
[
req_id
]
all_done_sending
.
add
(
req_id
)
# Handle timeout to avoid stranding blocks on remote.
now
=
time
.
perf_counter
()
while
self
.
_reqs_to_send
:
req_id
,
expires
=
next
(
iter
(
self
.
_reqs_to_send
.
items
()))
# Sorted dict, oldest requests are put first so we can exit early.
if
now
<
expires
:
break
del
self
.
_reqs_to_send
[
req_id
]
done_sending
.
add
(
req_id
)
return
all_done_sending
,
all_done_recving
# Ranks 1 to N-1: send finished ids to Rank 0.
else
:
finished_req_ids
=
list
(
done_recving
.
union
(
done_sending
))
self
.
tp_group
.
send_object
(
finished_req_ids
,
dst
=
0
)
# Unused as only Rank 0 results are sent to scheduler.
return
done_sending
,
done_recving
return
done_sending
,
done_recving
def
_get_new_notifs
(
self
)
->
set
[
str
]:
"""
...
...
@@ -887,6 +862,7 @@ class NixlConnectorWorker:
tp_ratio
):
notified_req_ids
.
add
(
req_id
)
del
self
.
consumer_notification_counts_by_req
[
req_id
]
del
self
.
_reqs_to_send
[
req_id
]
return
notified_req_ids
def
_pop_done_transfers
(
...
...
@@ -921,7 +897,7 @@ class NixlConnectorWorker:
Start loading by triggering non-blocking nixl_xfer.
We check for these trnxs to complete in each step().
"""
for
req_id
,
meta
in
metadata
.
req
uests
.
items
():
for
req_id
,
meta
in
metadata
.
req
s_to_recv
.
items
():
remote_engine_id
=
meta
.
remote_engine_id
logger
.
debug
(
"start_load_kv for request %s from remote engine %s. "
...
...
@@ -943,6 +919,9 @@ class NixlConnectorWorker:
while
not
self
.
_ready_requests
.
empty
():
self
.
_read_blocks_for_req
(
*
self
.
_ready_requests
.
get_nowait
())
# Add to requests that are waiting to be read and track expiration.
self
.
_reqs_to_send
.
update
(
metadata
.
reqs_to_send
)
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
logger
.
debug
(
"Remote agent %s available, calling _read_blocks for req %s"
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
711aa9d5
...
...
@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
import
(
P2pNcclEngine
)
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert
self
.
p2p_nccl_engine
is
not
None
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...]
num_pages
,
page_size
=
layer
.
shape
[
1
],
layer
.
shape
[
2
]
return
layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)[:,
slot_mapping
,
...]
connector_metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
P2pNcclConnectorMetadata
)
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_layer
,
remote_address
,
request
.
slot_mapping
,
isinstance
(
attn_metadata
,
MLACommonMetadata
))
def
wait_for_save
(
self
):
if
self
.
is_producer
:
...
...
@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert
self
.
p2p_nccl_engine
is
not
None
forward_context
:
ForwardContext
=
get_forward_context
()
no_compile_layers
=
(
self
.
_vllm_config
.
compilation_config
.
static_forward_context
)
return
self
.
p2p_nccl_engine
.
get_finished
(
finished_req_ids
,
forward_context
)
no_compile_layers
)
# ==============================
# Scheduler-side methods
...
...
@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_ids
=
block_ids
,
block_size
=
self
.
_block_size
)
# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)
self
.
_requests_need_load
.
clear
()
return
meta
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
711aa9d5
...
...
@@ -8,7 +8,8 @@ import time
import
typing
from
collections
import
deque
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
import
msgpack
import
torch
...
...
@@ -21,9 +22,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool
)
from
vllm.utils
import
current_stream
,
get_ip
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
logger
=
logging
.
getLogger
(
__name__
)
DEFAULT_MEM_POOL_SIZE_GB
=
32
...
...
@@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str):
os
.
environ
.
pop
(
var
,
None
)
@
dataclass
class
SendQueueItem
:
tensor_id
:
str
remote_address
:
str
tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
is_mla
:
bool
class
P2pNcclEngine
:
def
__init__
(
self
,
...
...
@@ -112,24 +119,26 @@ class P2pNcclEngine:
self
.
send_stream
=
torch
.
cuda
.
Stream
()
self
.
recv_stream
=
torch
.
cuda
.
Stream
()
mem_pool_size_gb
=
self
.
config
.
get_from_extra_config
(
"mem_pool_size_gb"
,
DEFAULT_MEM_POOL_SIZE_GB
)
self
.
pool
=
TensorMemoryPool
(
max_block_size
=
int
(
mem_pool_size_gb
)
*
1024
**
3
)
# GB
mem_pool_size_gb
=
float
(
self
.
config
.
get_from_extra_config
(
"mem_pool_size_gb"
,
DEFAULT_MEM_POOL_SIZE_GB
))
self
.
pool
=
TensorMemoryPool
(
max_block_size
=
int
(
mem_pool_size_gb
*
1024
**
3
))
# GB
# The sending type includes tree mutually exclusive options:
# PUT, GET, PUT_ASYNC.
self
.
send_type
=
self
.
config
.
get_from_extra_config
(
"send_type"
,
"PUT"
)
self
.
send_type
=
self
.
config
.
get_from_extra_config
(
"send_type"
,
"PUT_ASYNC"
)
if
self
.
send_type
==
"GET"
:
# tensor_id: torch.Tensor
self
.
send_store
:
dict
[
str
,
torch
.
Tensor
]
=
{}
else
:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self
.
send_queue
:
deque
[
list
[
Any
]
]
=
deque
()
self
.
send_queue
:
deque
[
SendQueueItem
]
=
deque
()
self
.
send_request_id_to_tensor_ids
:
dict
[
str
,
set
[
str
]]
=
{}
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_send_thread
=
threading
.
Thread
(
target
=
self
.
_
send_async
,
self
.
_send_thread
=
threading
.
Thread
(
target
=
self
.
send_async
,
daemon
=
True
)
self
.
_send_thread
.
start
()
...
...
@@ -146,13 +155,12 @@ class P2pNcclEngine:
"nccl_num_channels"
,
"8"
)
self
.
_listener_thread
=
threading
.
Thread
(
target
=
self
.
_
listen_for_requests
,
daemon
=
True
)
target
=
self
.
listen_for_requests
,
daemon
=
True
)
self
.
_listener_thread
.
start
()
self
.
_ping_thread
=
None
if
port_offset
==
0
and
self
.
proxy_address
!=
""
:
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
_ping
,
daemon
=
True
)
self
.
_ping_thread
=
threading
.
Thread
(
target
=
self
.
ping
,
daemon
=
True
)
self
.
_ping_thread
.
start
()
logger
.
info
(
...
...
@@ -162,7 +170,7 @@ class P2pNcclEngine:
self
.
http_address
,
self
.
zmq_address
,
self
.
proxy_address
,
self
.
send_type
,
self
.
buffer_size_threshold
,
self
.
nccl_num_channels
)
def
_
create_connect
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
def
create_connect
(
self
,
remote_address
:
typing
.
Optional
[
str
]
=
None
):
assert
remote_address
is
not
None
if
remote_address
not
in
self
.
socks
:
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
...
...
@@ -184,7 +192,7 @@ class P2pNcclEngine:
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
self
.
comms
[
remote_address
]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👉%s, MyRank:
%s"
,
logger
.
info
(
"🤝ncclCommInitRank Success, %s👉%s, MyRank:%s"
,
self
.
zmq_address
,
remote_address
,
rank
)
return
self
.
socks
[
remote_address
],
self
.
comms
[
remote_address
]
...
...
@@ -194,44 +202,54 @@ class P2pNcclEngine:
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
str
]
=
None
,
slot_mapping
:
torch
.
Tensor
=
None
,
is_mla
:
bool
=
False
,
)
->
bool
:
if
remote_address
is
None
:
with
self
.
recv_store_cv
:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
recv_store_cv
.
notify
()
return
True
else
:
if
self
.
send_type
==
"PUT"
:
return
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
elif
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
self
.
send_queue
.
append
([
tensor_id
,
remote_address
,
tensor
])
self
.
send_queue_cv
.
notify
()
else
:
# GET
with
self
.
send_store_cv
:
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
while
(
self
.
buffer_size
+
tensor_size
>
self
.
buffer_size_threshold
):
oldest_tenser_id
=
next
(
iter
(
self
.
send_store
))
oldest_tenser
=
self
.
send_store
.
pop
(
oldest_tenser_id
)
oldest_tenser_size
=
oldest_tenser
.
element_size
(
)
*
oldest_tenser
.
numel
()
self
.
buffer_size
-=
oldest_tenser_size
logger
.
info
(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
remote_address
,
tensor_id
,
tensor_size
,
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
send_store
[
tensor_id
]
=
tensor
self
.
buffer_size
+=
tensor_size
logger
.
debug
(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
remote_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
self
.
rank
,
self
.
buffer_size
,
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
item
=
SendQueueItem
(
tensor_id
=
tensor_id
,
remote_address
=
remote_address
,
tensor
=
tensor
,
slot_mapping
=
slot_mapping
,
is_mla
=
is_mla
)
if
self
.
send_type
==
"PUT"
:
return
self
.
send_sync
(
item
)
if
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
self
.
send_queue
.
append
(
item
)
self
.
send_queue_cv
.
notify
()
return
True
# GET
with
self
.
send_store_cv
:
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
while
(
self
.
buffer_size
+
tensor_size
>
self
.
buffer_size_threshold
):
oldest_tenser_id
=
next
(
iter
(
self
.
send_store
))
oldest_tenser
=
self
.
send_store
.
pop
(
oldest_tenser_id
)
oldest_tenser_size
=
oldest_tenser
.
element_size
(
)
*
oldest_tenser
.
numel
()
self
.
buffer_size
-=
oldest_tenser_size
logger
.
info
(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d"
,
remote_address
,
tensor_id
,
tensor_size
,
self
.
buffer_size
,
oldest_tenser_size
,
self
.
rank
)
self
.
send_store
[
tensor_id
]
=
tensor
self
.
buffer_size
+=
tensor_size
logger
.
debug
(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)"
,
remote_address
,
tensor_id
,
tensor_size
,
tensor
.
shape
,
self
.
rank
,
self
.
buffer_size
,
self
.
buffer_size
/
self
.
buffer_size_threshold
*
100
)
return
True
def
recv_tensor
(
...
...
@@ -267,7 +285,7 @@ class P2pNcclEngine:
return
None
if
remote_address
not
in
self
.
socks
:
self
.
_
create_connect
(
remote_address
)
self
.
create_connect
(
remote_address
)
sock
=
self
.
socks
[
remote_address
]
comm
,
rank
=
self
.
comms
[
remote_address
]
...
...
@@ -282,121 +300,121 @@ class P2pNcclEngine:
remote_address
,
tensor_id
,
data
[
"ret"
])
return
None
tensor
=
torch
.
empty
(
data
[
"shape"
],
dtype
=
getattr
(
torch
,
data
[
"dtype"
]),
device
=
self
.
device
)
with
torch
.
cuda
.
stream
(
self
.
recv_stream
):
tensor
=
torch
.
empty
(
data
[
"shape"
],
dtype
=
getattr
(
torch
,
data
[
"dtype"
]),
device
=
self
.
device
)
self
.
_
recv
(
comm
,
tensor
,
rank
^
1
,
self
.
recv_stream
)
self
.
recv
(
comm
,
tensor
,
rank
^
1
,
self
.
recv_stream
)
return
tensor
def
_
listen_for_requests
(
self
):
def
listen_for_requests
(
self
):
while
True
:
socks
=
dict
(
self
.
poller
.
poll
())
if
self
.
router_socket
in
socks
:
remote_address
,
message
=
self
.
router_socket
.
recv_multipart
()
data
=
msgpack
.
loads
(
message
)
if
data
[
"cmd"
]
==
"NEW"
:
unique_id
=
self
.
nccl
.
unique_id_from_bytes
(
bytes
(
data
[
"unique_id"
]))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
1
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
self
.
comms
[
remote_address
.
decode
()]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
self
.
zmq_address
,
remote_address
.
decode
(),
rank
)
elif
data
[
"cmd"
]
==
"PUT"
:
tensor_id
=
data
[
"tensor_id"
]
try
:
with
torch
.
cuda
.
stream
(
self
.
recv_stream
):
tensor
=
torch
.
empty
(
data
[
"shape"
],
dtype
=
getattr
(
torch
,
data
[
"dtype"
]),
device
=
self
.
device
)
self
.
router_socket
.
send_multipart
(
[
remote_address
,
b
"0"
])
comm
,
rank
=
self
.
comms
[
remote_address
.
decode
()]
self
.
_recv
(
comm
,
tensor
,
rank
^
1
,
self
.
recv_stream
)
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
if
(
self
.
buffer_size
+
tensor_size
>
self
.
buffer_size_threshold
):
# Store Tensor in memory pool
addr
=
self
.
pool
.
store_tensor
(
tensor
)
tensor
=
(
addr
,
tensor
.
dtype
,
tensor
.
shape
)
logger
.
warning
(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d"
,
self
.
zmq_address
,
remote_address
.
decode
(),
data
,
addr
)
else
:
self
.
buffer_size
+=
tensor_size
except
torch
.
cuda
.
OutOfMemoryError
:
self
.
router_socket
.
send_multipart
(
[
remote_address
,
b
"1"
])
tensor
=
None
if
self
.
router_socket
not
in
socks
:
continue
remote_address
,
message
=
self
.
router_socket
.
recv_multipart
()
data
=
msgpack
.
loads
(
message
)
if
data
[
"cmd"
]
==
"NEW"
:
unique_id
=
self
.
nccl
.
unique_id_from_bytes
(
bytes
(
data
[
"unique_id"
]))
with
torch
.
cuda
.
device
(
self
.
device
):
rank
=
1
with
set_p2p_nccl_context
(
self
.
nccl_num_channels
):
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
2
,
unique_id
,
rank
)
self
.
comms
[
remote_address
.
decode
()]
=
(
comm
,
rank
)
logger
.
info
(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s"
,
self
.
zmq_address
,
remote_address
.
decode
(),
rank
)
elif
data
[
"cmd"
]
==
"PUT"
:
tensor_id
=
data
[
"tensor_id"
]
try
:
with
torch
.
cuda
.
stream
(
self
.
recv_stream
):
tensor
=
torch
.
empty
(
data
[
"shape"
],
dtype
=
getattr
(
torch
,
data
[
"dtype"
]),
device
=
self
.
device
)
self
.
router_socket
.
send_multipart
([
remote_address
,
b
"0"
])
comm
,
rank
=
self
.
comms
[
remote_address
.
decode
()]
self
.
recv
(
comm
,
tensor
,
rank
^
1
,
self
.
recv_stream
)
tensor_size
=
tensor
.
element_size
()
*
tensor
.
numel
()
if
(
self
.
buffer_size
+
tensor_size
>
self
.
buffer_size_threshold
):
# Store Tensor in memory pool
addr
=
self
.
pool
.
store_tensor
(
tensor
)
tensor
=
(
addr
,
tensor
.
dtype
,
tensor
.
shape
)
logger
.
warning
(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s"
,
self
.
zmq_address
,
remote_address
.
decode
(),
data
)
with
self
.
recv_store_cv
:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
_have_received_tensor_id
(
tensor_id
)
self
.
recv_store_cv
.
notify
()
elif
data
[
"cmd"
]
==
"GET"
:
tensor_id
=
data
[
"tensor_id"
]
with
self
.
send_store_cv
:
tensor
=
self
.
send_store
.
pop
(
tensor_id
,
None
)
if
tensor
is
not
None
:
data
=
{
"ret"
:
0
,
"shape"
:
tensor
.
shape
,
"dtype"
:
str
(
tensor
.
dtype
).
replace
(
"torch."
,
""
)
}
# LRU
self
.
send_store
[
tensor_id
]
=
tensor
self
.
_have_sent_tensor_id
(
tensor_id
)
else
:
data
=
{
"ret"
:
1
}
self
.
router_socket
.
send_multipart
(
[
remote_address
,
msgpack
.
dumps
(
data
)])
if
data
[
"ret"
]
==
0
:
comm
,
rank
=
self
.
comms
[
remote_address
.
decode
()]
self
.
_send
(
comm
,
tensor
.
to
(
self
.
device
),
rank
^
1
,
self
.
send_stream
)
else
:
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d"
,
self
.
zmq_address
,
remote_address
.
decode
(),
data
,
addr
)
else
:
self
.
buffer_size
+=
tensor_size
except
torch
.
cuda
.
OutOfMemoryError
:
self
.
router_socket
.
send_multipart
([
remote_address
,
b
"1"
])
tensor
=
None
logger
.
warning
(
"🚧Unexpected, Received message from %s, data:%s"
,
remote_address
,
data
)
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s"
,
self
.
zmq_address
,
remote_address
.
decode
(),
data
)
def
_have_sent_tensor_id
(
self
,
tensor_id
:
str
):
with
self
.
recv_store_cv
:
self
.
recv_store
[
tensor_id
]
=
tensor
self
.
have_received_tensor_id
(
tensor_id
)
self
.
recv_store_cv
.
notify
()
elif
data
[
"cmd"
]
==
"GET"
:
tensor_id
=
data
[
"tensor_id"
]
with
self
.
send_store_cv
:
tensor
=
self
.
send_store
.
pop
(
tensor_id
,
None
)
if
tensor
is
not
None
:
data
=
{
"ret"
:
0
,
"shape"
:
tensor
.
shape
,
"dtype"
:
str
(
tensor
.
dtype
).
replace
(
"torch."
,
""
)
}
# LRU
self
.
send_store
[
tensor_id
]
=
tensor
self
.
have_sent_tensor_id
(
tensor_id
)
else
:
data
=
{
"ret"
:
1
}
self
.
router_socket
.
send_multipart
(
[
remote_address
,
msgpack
.
dumps
(
data
)])
if
data
[
"ret"
]
==
0
:
comm
,
rank
=
self
.
comms
[
remote_address
.
decode
()]
self
.
send
(
comm
,
tensor
.
to
(
self
.
device
),
rank
^
1
,
self
.
send_stream
)
else
:
logger
.
warning
(
"🚧Unexpected, Received message from %s, data:%s"
,
remote_address
,
data
)
def
have_sent_tensor_id
(
self
,
tensor_id
:
str
):
request_id
=
tensor_id
.
split
(
'#'
)[
0
]
if
request_id
not
in
self
.
send_request_id_to_tensor_ids
:
self
.
send_request_id_to_tensor_ids
[
request_id
]
=
set
()
self
.
send_request_id_to_tensor_ids
[
request_id
].
add
(
tensor_id
)
def
_
have_received_tensor_id
(
self
,
tensor_id
:
str
):
def
have_received_tensor_id
(
self
,
tensor_id
:
str
):
request_id
=
tensor_id
.
split
(
'#'
)[
0
]
if
request_id
not
in
self
.
recv_request_id_to_tensor_ids
:
self
.
recv_request_id_to_tensor_ids
[
request_id
]
=
set
()
self
.
recv_request_id_to_tensor_ids
[
request_id
].
add
(
tensor_id
)
def
_
send_async
(
self
):
def
send_async
(
self
):
while
True
:
with
self
.
send_queue_cv
:
while
not
self
.
send_queue
:
self
.
send_queue_cv
.
wait
()
tensor_id
,
remote_address
,
tensor
=
self
.
send_queue
.
popleft
()
item
=
self
.
send_queue
.
popleft
()
if
not
self
.
send_queue
:
self
.
send_queue_cv
.
notify
()
self
.
_
send_sync
(
tensor_id
,
tensor
,
remote_address
)
self
.
send_sync
(
item
)
def
wait_for_sent
(
self
):
if
self
.
send_type
==
"PUT_ASYNC"
:
...
...
@@ -409,22 +427,21 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d"
,
duration
*
1000
,
self
.
rank
)
def
_send_sync
(
self
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
str
]
=
None
,
)
->
bool
:
if
remote_address
is
None
:
def
send_sync
(
self
,
item
:
SendQueueItem
)
->
bool
:
if
item
.
remote_address
is
None
:
return
False
if
remote_address
not
in
self
.
socks
:
self
.
_
create_connect
(
remote_address
)
if
item
.
remote_address
not
in
self
.
socks
:
self
.
create_connect
(
item
.
remote_address
)
sock
=
self
.
socks
[
remote_address
]
comm
,
rank
=
self
.
comms
[
remote_address
]
with
self
.
send_stream
:
tensor
=
self
.
extract_kv_from_layer
(
item
.
is_mla
,
item
.
tensor
,
item
.
slot_mapping
)
sock
=
self
.
socks
[
item
.
remote_address
]
comm
,
rank
=
self
.
comms
[
item
.
remote_address
]
data
=
{
"cmd"
:
"PUT"
,
"tensor_id"
:
tensor_id
,
"tensor_id"
:
item
.
tensor_id
,
"shape"
:
tensor
.
shape
,
"dtype"
:
str
(
tensor
.
dtype
).
replace
(
"torch."
,
""
)
}
...
...
@@ -435,20 +452,21 @@ class P2pNcclEngine:
logger
.
error
(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s"
,
self
.
zmq_address
,
remote_address
,
rank
,
data
,
tensor
.
shape
,
self
.
zmq_address
,
item
.
remote_address
,
rank
,
data
,
tensor
.
shape
,
tensor
.
element_size
()
*
tensor
.
numel
()
/
1024
**
3
,
response
.
decode
())
return
False
self
.
_
send
(
comm
,
tensor
.
to
(
self
.
device
),
rank
^
1
,
self
.
send_stream
)
self
.
send
(
comm
,
tensor
.
to
(
self
.
device
),
rank
^
1
,
self
.
send_stream
)
if
self
.
send_type
==
"PUT_ASYNC"
:
self
.
_
have_sent_tensor_id
(
tensor_id
)
self
.
have_sent_tensor_id
(
item
.
tensor_id
)
return
True
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
],
forward_context
:
"ForwardContext"
self
,
finished_req_ids
:
set
[
str
],
no_compile_layers
)
->
tuple
[
Optional
[
set
[
str
]],
Optional
[
set
[
str
]]]:
"""
Notifies worker-side connector ids of requests that have
...
...
@@ -463,7 +481,7 @@ class P2pNcclEngine:
# Clear the buffer upon request completion.
for
request_id
in
finished_req_ids
:
for
layer_name
in
forward_context
.
no_compile_layers
:
for
layer_name
in
no_compile_layers
:
tensor_id
=
request_id
+
"#"
+
layer_name
if
tensor_id
in
self
.
recv_store
:
with
self
.
recv_store_cv
:
...
...
@@ -472,7 +490,6 @@ class P2pNcclEngine:
request_id
,
None
)
self
.
recv_request_id_to_tensor_ids
.
pop
(
request_id
,
None
)
addr
=
0
if
isinstance
(
tensor
,
tuple
):
addr
,
_
,
_
=
tensor
self
.
pool
.
free
(
addr
)
...
...
@@ -485,7 +502,7 @@ class P2pNcclEngine:
return
finished_sending
or
None
,
finished_recving
or
None
def
_
ping
(
self
):
def
ping
(
self
):
sock
=
self
.
context
.
socket
(
zmq
.
DEALER
)
sock
.
setsockopt_string
(
zmq
.
IDENTITY
,
self
.
zmq_address
)
logger
.
debug
(
"ping start, zmq_address:%s"
,
self
.
zmq_address
)
...
...
@@ -499,7 +516,7 @@ class P2pNcclEngine:
sock
.
send
(
msgpack
.
dumps
(
data
))
time
.
sleep
(
3
)
def
_
send
(
self
,
comm
,
tensor
:
torch
.
Tensor
,
dst
:
int
,
stream
=
None
):
def
send
(
self
,
comm
,
tensor
:
torch
.
Tensor
,
dst
:
int
,
stream
=
None
):
assert
tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
...
...
@@ -512,7 +529,7 @@ class P2pNcclEngine:
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
stream
.
synchronize
()
def
_
recv
(
self
,
comm
,
tensor
:
torch
.
Tensor
,
src
:
int
,
stream
=
None
):
def
recv
(
self
,
comm
,
tensor
:
torch
.
Tensor
,
src
:
int
,
stream
=
None
):
assert
tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
...
...
@@ -531,3 +548,21 @@ class P2pNcclEngine:
self
.
_send_thread
.
join
()
if
self
.
_ping_thread
is
not
None
:
self
.
_ping_thread
.
join
()
@
staticmethod
def
extract_kv_from_layer
(
is_mla
:
bool
,
layer
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if
is_mla
:
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...]
num_pages
,
page_size
=
layer
.
shape
[
1
],
layer
.
shape
[
2
]
return
layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)[:,
slot_mapping
,
...]
vllm/distributed/parallel_state.py
View file @
711aa9d5
...
...
@@ -240,6 +240,8 @@ class GroupCoordinator:
if
current_platform
.
is_cuda_alike
():
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
elif
current_platform
.
is_xpu
():
self
.
device
=
torch
.
device
(
f
"xpu:
{
local_rank
}
"
)
elif
current_platform
.
is_out_of_tree
():
self
.
device
=
torch
.
device
(
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
)
...
...
@@ -270,6 +272,9 @@ class GroupCoordinator:
self
.
use_custom_op_call
=
(
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_tpu
())
self
.
use_cpu_custom_send_recv
=
(
current_platform
.
is_cpu
()
and
hasattr
(
torch
.
ops
.
_C
,
"init_shm_manager"
))
@
property
def
first_rank
(
self
):
"""Return the global rank of the first process in the group"""
...
...
@@ -381,6 +386,12 @@ class GroupCoordinator:
dim
:
int
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
all_gather
(
input_
,
dim
)
def
all_gatherv
(
self
,
input_
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
dim
:
int
=
0
,
sizes
:
Optional
[
list
[
int
]]
=
None
):
return
self
.
device_communicator
.
all_gatherv
(
input_
,
dim
,
sizes
)
def
reduce_scatter
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
...
...
@@ -399,6 +410,12 @@ class GroupCoordinator:
else
:
return
self
.
_reduce_scatter_out_place
(
input_
,
dim
)
def
reduce_scatterv
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
sizes
:
Optional
[
list
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
reduce_scatterv
(
input_
,
dim
,
sizes
)
def
_reduce_scatter_out_place
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
reduce_scatter
(
input_
,
dim
)
...
...
@@ -649,6 +666,11 @@ class GroupCoordinator:
dst
=
(
self
.
rank_in_group
+
1
)
%
self
.
world_size
assert
dst
<
self
.
world_size
,
f
"Invalid dst rank (
{
dst
}
)"
if
self
.
use_cpu_custom_send_recv
:
self
.
device_communicator
.
send_tensor_dict
(
# type: ignore
tensor_dict
,
dst
)
return
None
metadata_list
:
list
[
tuple
[
Any
,
Any
]]
=
[]
assert
isinstance
(
tensor_dict
,
...
...
@@ -704,6 +726,10 @@ class GroupCoordinator:
src
=
(
self
.
rank_in_group
-
1
)
%
self
.
world_size
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
if
self
.
use_cpu_custom_send_recv
:
return
self
.
device_communicator
.
recv_tensor_dict
(
# type: ignore
src
)
recv_metadata_list
=
self
.
recv_object
(
src
=
src
)
tensor_dict
:
dict
[
str
,
Any
]
=
{}
for
key
,
value
in
recv_metadata_list
:
...
...
@@ -1318,13 +1344,13 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
def
is_global_first_rank
()
->
bool
:
"""
Check if the current process is the first rank globally across all
Check if the current process is the first rank globally across all
parallelism strategies (PP, TP, DP, EP, etc.).
Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
or `get_pp_group().is_first_rank`, this function checks the global rank
across all parallelism dimensions.
Returns:
bool: True if this is the global first rank (rank 0), False otherwise.
Returns True if distributed is not initialized (single process).
...
...
@@ -1353,7 +1379,7 @@ def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
Args:
pg: The process group to analyze
Returns:
int: The total number of nodes
"""
...
...
vllm/engine/arg_utils.py
View file @
711aa9d5
...
...
@@ -10,16 +10,16 @@ import functools
import
json
import
sys
import
threading
import
warnings
from
dataclasses
import
MISSING
,
dataclass
,
fields
,
is_dataclass
from
itertools
import
permutations
from
typing
import
(
Annotated
,
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
get_args
,
get_origin
)
from
typing
import
(
TYPE_CHECKING
,
Annotated
,
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
get_args
,
get_origin
)
import
regex
as
re
import
torch
from
pydantic
import
TypeAdapter
,
ValidationError
from
typing_extensions
import
TypeIs
,
deprecated
from
typing_extensions
import
TypeIs
import
vllm.envs
as
envs
from
vllm.config
import
(
BlockSize
,
CacheConfig
,
CacheDType
,
CompilationConfig
,
...
...
@@ -27,26 +27,33 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
GuidedDecodingBackendV1
,
HfOverrides
,
KVEventsConfig
,
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
Model
Config
,
Model
DType
,
Model
Impl
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
Pooler
Config
,
P
refixCachingHashAlgo
,
PromptAdapterConfig
,
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
LogprobsMode
,
LoRA
Config
,
Model
Config
,
Model
DType
,
ModelImpl
,
MultiModalConfig
,
Observability
Config
,
P
arallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
TokenizerPoolConfig
,
VllmConfig
,
get_attr_docs
,
get_field
)
from
vllm.executor.executor_base
import
ExecutorBase
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
from
vllm.logger
import
init_logger
from
vllm.
model_executor.layers.quantization
import
QuantizationMethods
from
vllm.
platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.test_utils
import
MODEL_WEIGHTS_S3_BUCKET
,
MODELS_ON_S3
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
STR_DUAL_CHUNK_FLASH_ATTN_VAL
,
FlexibleArgumentParser
,
GiB_bytes
,
get_ip
,
is_in_ray_actor
)
# yapf: enable
if
TYPE_CHECKING
:
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.usage.usage_lib
import
UsageContext
else
:
ExecutorBase
=
Any
QuantizationMethods
=
Any
UsageContext
=
Any
logger
=
init_logger
(
__name__
)
# object is used to allow for special typing forms
...
...
@@ -59,8 +66,6 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def
_parse_type
(
val
:
str
)
->
T
:
try
:
if
return_type
is
json
.
loads
and
not
re
.
match
(
"^{.*}$"
,
val
):
return
cast
(
T
,
nullable_kvs
(
val
))
return
return_type
(
val
)
except
ValueError
as
e
:
raise
argparse
.
ArgumentTypeError
(
...
...
@@ -81,47 +86,11 @@ def optional_type(
def
union_dict_and_str
(
val
:
str
)
->
Optional
[
Union
[
str
,
dict
[
str
,
str
]]]:
if
not
re
.
match
(
"^
{.*}$"
,
val
):
if
not
re
.
match
(
r
"(?s)^\s*
{.*}
\s*
$"
,
val
):
return
str
(
val
)
return
optional_type
(
json
.
loads
)(
val
)
@
deprecated
(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead."
)
def
nullable_kvs
(
val
:
str
)
->
dict
[
str
,
int
]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.
Args:
val: String value to be parsed.
Returns:
Dictionary with parsed values.
"""
out_dict
:
dict
[
str
,
int
]
=
{}
for
item
in
val
.
split
(
","
):
kv_parts
=
[
part
.
lower
().
strip
()
for
part
in
item
.
split
(
"="
)]
if
len
(
kv_parts
)
!=
2
:
raise
argparse
.
ArgumentTypeError
(
"Each item should be in the form KEY=VALUE"
)
key
,
value
=
kv_parts
try
:
parsed_value
=
int
(
value
)
except
ValueError
as
exc
:
msg
=
f
"Failed to parse value of item
{
key
}
=
{
value
}
"
raise
argparse
.
ArgumentTypeError
(
msg
)
from
exc
if
key
in
out_dict
and
out_dict
[
key
]
!=
parsed_value
:
raise
argparse
.
ArgumentTypeError
(
f
"Conflicting values specified for key:
{
key
}
"
)
out_dict
[
key
]
=
parsed_value
return
out_dict
def
is_type
(
type_hint
:
TypeHint
,
type
:
TypeHintT
)
->
TypeIs
[
TypeHintT
]:
"""Check if the type hint is a specific type."""
return
type_hint
is
type
or
get_origin
(
type_hint
)
is
type
...
...
@@ -171,6 +140,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return
type_hints
def
is_online_quantization
(
quantization
:
Any
)
->
bool
:
return
quantization
in
[
"inc"
]
@
functools
.
lru_cache
(
maxsize
=
30
)
def
_compute_kwargs
(
cls
:
ConfigType
)
->
dict
[
str
,
Any
]:
cls_docs
=
get_attr_docs
(
cls
)
...
...
@@ -199,14 +172,17 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs
[
name
]
=
{
"default"
:
default
,
"help"
:
help
}
# Set other kwargs based on the type hints
json_tip
=
"""
\n\n
Should either be a valid JSON string or JSON keys
passed individually. For example, the following sets of arguments are
equivalent:
\n\n
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
\n
Additionally, list elements can be passed individually using '+':
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`
\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`
\n\n
"""
json_tip
=
"""Should either be a valid JSON string or JSON keys
passed individually. For example, the following sets of arguments are
equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Additionally, list elements can be passed individually using `+`:
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`
\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`"""
if
dataclass_cls
is
not
None
:
def
parse_dataclass
(
val
:
str
,
cls
=
dataclass_cls
)
->
Any
:
...
...
@@ -218,7 +194,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
raise
argparse
.
ArgumentTypeError
(
repr
(
e
))
from
e
kwargs
[
name
][
"type"
]
=
parse_dataclass
kwargs
[
name
][
"help"
]
+=
json_tip
kwargs
[
name
][
"help"
]
+=
f
"
\n\n
{
json_tip
}
"
elif
contains_type
(
type_hints
,
bool
):
# Creates --no-<name> and --<name> flags
kwargs
[
name
][
"action"
]
=
argparse
.
BooleanOptionalAction
...
...
@@ -254,7 +230,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs
[
name
][
"type"
]
=
union_dict_and_str
elif
contains_type
(
type_hints
,
dict
):
kwargs
[
name
][
"type"
]
=
parse_type
(
json
.
loads
)
kwargs
[
name
][
"help"
]
+=
json_tip
kwargs
[
name
][
"help"
]
+=
f
"
\n\n
{
json_tip
}
"
elif
(
contains_type
(
type_hints
,
str
)
or
any
(
is_not_builtin
(
th
)
for
th
in
type_hints
)):
kwargs
[
name
][
"type"
]
=
str
...
...
@@ -320,9 +296,11 @@ class EngineArgs:
tensor_parallel_size
:
int
=
ParallelConfig
.
tensor_parallel_size
data_parallel_size
:
int
=
ParallelConfig
.
data_parallel_size
data_parallel_rank
:
Optional
[
int
]
=
None
data_parallel_start_rank
:
Optional
[
int
]
=
None
data_parallel_size_local
:
Optional
[
int
]
=
None
data_parallel_address
:
Optional
[
str
]
=
None
data_parallel_rpc_port
:
Optional
[
int
]
=
None
data_parallel_hybrid_lb
:
bool
=
False
data_parallel_backend
:
str
=
ParallelConfig
.
data_parallel_backend
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
enable_eplb
:
bool
=
ParallelConfig
.
enable_eplb
...
...
@@ -338,7 +316,6 @@ class EngineArgs:
CacheConfig
.
prefix_caching_hash_algo
disable_sliding_window
:
bool
=
ModelConfig
.
disable_sliding_window
disable_cascade_attn
:
bool
=
ModelConfig
.
disable_cascade_attn
use_v2_block_manager
:
bool
=
True
swap_space
:
float
=
CacheConfig
.
swap_space
cpu_offload_gb
:
float
=
CacheConfig
.
cpu_offload_gb
gpu_memory_utilization
:
float
=
CacheConfig
.
gpu_memory_utilization
...
...
@@ -350,6 +327,7 @@ class EngineArgs:
SchedulerConfig
.
long_prefill_token_threshold
max_num_seqs
:
Optional
[
int
]
=
SchedulerConfig
.
max_num_seqs
max_logprobs
:
int
=
ModelConfig
.
max_logprobs
logprobs_mode
:
LogprobsMode
=
ModelConfig
.
logprobs_mode
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
ModelConfig
.
revision
code_revision
:
Optional
[
str
]
=
ModelConfig
.
code_revision
...
...
@@ -362,15 +340,9 @@ class EngineArgs:
enforce_eager
:
bool
=
ModelConfig
.
enforce_eager
max_seq_len_to_capture
:
int
=
ModelConfig
.
max_seq_len_to_capture
disable_custom_all_reduce
:
bool
=
ParallelConfig
.
disable_custom_all_reduce
# The following three fields are deprecated and will be removed in a future
# release. Setting them will have no effect. Please remove them from your
# configurations.
tokenizer_pool_size
:
int
=
TokenizerPoolConfig
.
pool_size
tokenizer_pool_type
:
str
=
TokenizerPoolConfig
.
pool_type
tokenizer_pool_extra_config
:
dict
=
\
get_field
(
TokenizerPoolConfig
,
"extra_config"
)
limit_mm_per_prompt
:
dict
[
str
,
int
]
=
\
get_field
(
MultiModalConfig
,
"limit_per_prompt"
)
interleave_mm_strings
:
bool
=
MultiModalConfig
.
interleave_mm_strings
media_io_kwargs
:
dict
[
str
,
dict
[
str
,
Any
]]
=
get_field
(
MultiModalConfig
,
"media_io_kwargs"
)
...
...
@@ -383,19 +355,14 @@ class EngineArgs:
enable_lora_bias
:
bool
=
LoRAConfig
.
bias_enabled
max_loras
:
int
=
LoRAConfig
.
max_loras
max_lora_rank
:
int
=
LoRAConfig
.
max_lora_rank
default_mm_loras
:
Optional
[
Dict
[
str
,
str
]]
=
\
LoRAConfig
.
default_mm_loras
fully_sharded_loras
:
bool
=
LoRAConfig
.
fully_sharded_loras
max_cpu_loras
:
Optional
[
int
]
=
LoRAConfig
.
max_cpu_loras
lora_target_modules
:
Optional
[
List
[
str
]]
=
LoRAConfig
.
lora_target_modules
lora_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
LoRAConfig
.
lora_dtype
lora_extra_vocab_size
:
int
=
LoRAConfig
.
lora_extra_vocab_size
long_lora_scaling_factors
:
Optional
[
tuple
[
float
,
...]]
=
\
LoRAConfig
.
long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter
:
bool
=
False
max_prompt_adapters
:
int
=
PromptAdapterConfig
.
max_prompt_adapters
max_prompt_adapter_token
:
int
=
\
PromptAdapterConfig
.
max_prompt_adapter_token
device
:
Device
=
DeviceConfig
.
device
num_scheduler_steps
:
int
=
SchedulerConfig
.
num_scheduler_steps
multi_step_stream_outputs
:
bool
=
SchedulerConfig
.
multi_step_stream_outputs
ray_workers_use_nsight
:
bool
=
ParallelConfig
.
ray_workers_use_nsight
...
...
@@ -428,7 +395,6 @@ class EngineArgs:
speculative_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
num_speculative_heads
:
Optional
[
int
]
=
None
qlora_adapter_name_or_path
:
Optional
[
str
]
=
None
show_hidden_metrics_for_version
:
Optional
[
str
]
=
\
ObservabilityConfig
.
show_hidden_metrics_for_version
otlp_traces_endpoint
:
Optional
[
str
]
=
\
...
...
@@ -462,7 +428,6 @@ class EngineArgs:
additional_config
:
dict
[
str
,
Any
]
=
\
get_field
(
VllmConfig
,
"additional_config"
)
enable_reasoning
:
Optional
[
bool
]
=
None
# DEPRECATED
reasoning_parser
:
str
=
DecodingConfig
.
reasoning_backend
use_tqdm_on_load
:
bool
=
LoadConfig
.
use_tqdm_on_load
...
...
@@ -471,6 +436,10 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel
:
bool
=
\
ParallelConfig
.
enable_multimodal_encoder_data_parallel
async_scheduling
:
bool
=
SchedulerConfig
.
async_scheduling
# DEPRECATED
enable_prompt_adapter
:
bool
=
False
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
...
...
@@ -478,13 +447,6 @@ class EngineArgs:
if
isinstance
(
self
.
compilation_config
,
(
int
,
dict
)):
self
.
compilation_config
=
CompilationConfig
.
from_cli
(
str
(
self
.
compilation_config
))
if
self
.
qlora_adapter_name_or_path
is
not
None
:
warnings
.
warn
(
"The `qlora_adapter_name_or_path` is deprecated "
"and will be removed in v0.10.0. "
,
DeprecationWarning
,
stacklevel
=
2
,
)
# Setup plugins
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
...
...
@@ -531,6 +493,8 @@ class EngineArgs:
**
model_kwargs
[
"max_seq_len_to_capture"
])
model_group
.
add_argument
(
"--max-logprobs"
,
**
model_kwargs
[
"max_logprobs"
])
model_group
.
add_argument
(
"--logprobs-mode"
,
**
model_kwargs
[
"logprobs_mode"
])
model_group
.
add_argument
(
"--disable-sliding-window"
,
**
model_kwargs
[
"disable_sliding_window"
])
model_group
.
add_argument
(
"--disable-cascade-attn"
,
...
...
@@ -597,14 +561,6 @@ class EngineArgs:
**
load_kwargs
[
"ignore_patterns"
])
load_group
.
add_argument
(
"--use-tqdm-on-load"
,
**
load_kwargs
[
"use_tqdm_on_load"
])
load_group
.
add_argument
(
"--qlora-adapter-name-or-path"
,
type
=
str
,
default
=
None
,
help
=
"The `--qlora-adapter-name-or-path` has no effect, do not set"
" it, and it will be removed in v0.10.0."
,
deprecated
=
True
,
)
load_group
.
add_argument
(
'--pt-load-map-location'
,
**
load_kwargs
[
"pt_load_map_location"
])
...
...
@@ -625,15 +581,6 @@ class EngineArgs:
guided_decoding_group
.
add_argument
(
"--guided-decoding-disable-additional-properties"
,
**
guided_decoding_kwargs
[
"disable_additional_properties"
])
guided_decoding_group
.
add_argument
(
"--enable-reasoning"
,
action
=
argparse
.
BooleanOptionalAction
,
deprecated
=
True
,
help
=
"[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
"of v0.9.0. Use `--reasoning-parser` to specify the reasoning "
"parser backend instead. This flag (`--enable-reasoning`) will be "
"removed in v0.10.0. When `--reasoning-parser` is specified, "
"reasoning mode is automatically enabled."
)
guided_decoding_group
.
add_argument
(
"--reasoning-parser"
,
# This choices is a special case because it's not static
...
...
@@ -662,6 +609,11 @@ class EngineArgs:
type
=
int
,
help
=
'Data parallel rank of this instance. '
'When set, enables external load balancer mode.'
)
parallel_group
.
add_argument
(
'--data-parallel-start-rank'
,
'-dpr'
,
type
=
int
,
help
=
'Starting data parallel rank '
'for secondary nodes.'
)
parallel_group
.
add_argument
(
'--data-parallel-size-local'
,
'-dpl'
,
type
=
int
,
...
...
@@ -683,6 +635,9 @@ class EngineArgs:
default
=
'mp'
,
help
=
'Backend for data parallel, either '
'"mp" or "ray".'
)
parallel_group
.
add_argument
(
"--data-parallel-hybrid-lb"
,
**
parallel_kwargs
[
"data_parallel_hybrid_lb"
])
parallel_group
.
add_argument
(
"--enable-expert-parallel"
,
**
parallel_kwargs
[
"enable_expert_parallel"
])
...
...
@@ -736,19 +691,6 @@ class EngineArgs:
cache_group
.
add_argument
(
"--calculate-kv-scales"
,
**
cache_kwargs
[
"calculate_kv_scales"
])
# Tokenizer arguments
tokenizer_kwargs
=
get_kwargs
(
TokenizerPoolConfig
)
tokenizer_group
=
parser
.
add_argument_group
(
title
=
"TokenizerPoolConfig"
,
description
=
TokenizerPoolConfig
.
__doc__
,
)
tokenizer_group
.
add_argument
(
"--tokenizer-pool-size"
,
**
tokenizer_kwargs
[
"pool_size"
])
tokenizer_group
.
add_argument
(
"--tokenizer-pool-type"
,
**
tokenizer_kwargs
[
"pool_type"
])
tokenizer_group
.
add_argument
(
"--tokenizer-pool-extra-config"
,
**
tokenizer_kwargs
[
"extra_config"
])
# Multimodal related configs
multimodal_kwargs
=
get_kwargs
(
MultiModalConfig
)
multimodal_group
=
parser
.
add_argument_group
(
...
...
@@ -765,6 +707,9 @@ class EngineArgs:
multimodal_group
.
add_argument
(
"--disable-mm-preprocessor-cache"
,
**
multimodal_kwargs
[
"disable_mm_preprocessor_cache"
])
multimodal_group
.
add_argument
(
"--interleave-mm-strings"
,
**
multimodal_kwargs
[
"interleave_mm_strings"
])
# LoRA related configs
lora_kwargs
=
get_kwargs
(
LoRAConfig
)
...
...
@@ -789,39 +734,12 @@ class EngineArgs:
"--lora-dtype"
,
**
lora_kwargs
[
"lora_dtype"
],
)
lora_group
.
add_argument
(
"--long-lora-scaling-factors"
,
**
lora_kwargs
[
"long_lora_scaling_factors"
])
lora_group
.
add_argument
(
"--max-cpu-loras"
,
**
lora_kwargs
[
"max_cpu_loras"
])
lora_group
.
add_argument
(
"--fully-sharded-loras"
,
**
lora_kwargs
[
"fully_sharded_loras"
])
# PromptAdapter related configs
prompt_adapter_kwargs
=
get_kwargs
(
PromptAdapterConfig
)
prompt_adapter_group
=
parser
.
add_argument_group
(
title
=
"PromptAdapterConfig"
,
description
=
PromptAdapterConfig
.
__doc__
,
)
prompt_adapter_group
.
add_argument
(
"--enable-prompt-adapter"
,
action
=
argparse
.
BooleanOptionalAction
,
help
=
"If True, enable handling of PromptAdapters."
)
prompt_adapter_group
.
add_argument
(
"--max-prompt-adapters"
,
**
prompt_adapter_kwargs
[
"max_prompt_adapters"
])
prompt_adapter_group
.
add_argument
(
"--max-prompt-adapter-token"
,
**
prompt_adapter_kwargs
[
"max_prompt_adapter_token"
])
# Device arguments
device_kwargs
=
get_kwargs
(
DeviceConfig
)
device_group
=
parser
.
add_argument_group
(
title
=
"DeviceConfig"
,
description
=
DeviceConfig
.
__doc__
,
)
device_group
.
add_argument
(
"--device"
,
**
device_kwargs
[
"device"
],
deprecated
=
True
)
lora_group
.
add_argument
(
"--default-mm-loras"
,
**
lora_kwargs
[
"default_mm_loras"
])
# Speculative arguments
speculative_group
=
parser
.
add_argument_group
(
...
...
@@ -911,6 +829,8 @@ class EngineArgs:
scheduler_group
.
add_argument
(
"--disable-hybrid-kv-cache-manager"
,
**
scheduler_kwargs
[
"disable_hybrid_kv_cache_manager"
])
scheduler_group
.
add_argument
(
"--async-scheduling"
,
**
scheduler_kwargs
[
"async_scheduling"
])
# vLLM arguments
vllm_kwargs
=
get_kwargs
(
VllmConfig
)
...
...
@@ -928,18 +848,15 @@ class EngineArgs:
**
vllm_kwargs
[
"additional_config"
])
# Other arguments
parser
.
add_argument
(
'--use-v2-block-manager'
,
action
=
'store_true'
,
default
=
True
,
deprecated
=
True
,
help
=
'[DEPRECATED] block manager v1 has been '
'removed and SelfAttnBlockSpaceManager (i.e. '
'block manager v2) is now the default. '
'Setting this flag to True or False'
' has no effect on vLLM behavior.'
)
parser
.
add_argument
(
'--disable-log-stats'
,
action
=
'store_true'
,
help
=
'Disable logging statistics.'
)
parser
.
add_argument
(
'--enable-prompt-adapter'
,
action
=
'store_true'
,
deprecated
=
True
,
help
=
'[DEPRECATED] Prompt adapter has been '
'removed. Setting this flag to True or False'
' has no effect on vLLM behavior.'
)
return
parser
...
...
@@ -985,12 +902,14 @@ class EngineArgs:
enforce_eager
=
self
.
enforce_eager
,
max_seq_len_to_capture
=
self
.
max_seq_len_to_capture
,
max_logprobs
=
self
.
max_logprobs
,
logprobs_mode
=
self
.
logprobs_mode
,
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_cascade_attn
=
self
.
disable_cascade_attn
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
enable_prompt_embeds
=
self
.
enable_prompt_embeds
,
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
interleave_mm_strings
=
self
.
interleave_mm_strings
,
media_io_kwargs
=
self
.
media_io_kwargs
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
config_format
=
self
.
config_format
,
...
...
@@ -1007,14 +926,33 @@ class EngineArgs:
enable_chunked_prefill
=
self
.
enable_chunked_prefill
)
def
validate_tensorizer_args
(
self
):
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
)
for
key
in
self
.
model_loader_extra_config
:
if
key
in
TensorizerConfig
.
_fields
:
self
.
model_loader_extra_config
[
"tensorizer_config"
][
key
]
=
self
.
model_loader_extra_config
[
key
]
def
create_load_config
(
self
)
->
LoadConfig
:
if
self
.
quantization
==
"bitsandbytes"
:
self
.
load_format
=
"bitsandbytes"
if
self
.
load_format
==
"tensorizer"
:
if
hasattr
(
self
.
model_loader_extra_config
,
"to_serializable"
):
self
.
model_loader_extra_config
=
(
self
.
model_loader_extra_config
.
to_serializable
())
self
.
model_loader_extra_config
[
"tensorizer_config"
]
=
{}
self
.
model_loader_extra_config
[
"tensorizer_config"
][
"tensorizer_dir"
]
=
self
.
model
self
.
validate_tensorizer_args
()
return
LoadConfig
(
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
device
=
"cpu"
if
is_online_quantization
(
self
.
quantization
)
else
None
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
ignore_patterns
=
self
.
ignore_patterns
,
use_tqdm_on_load
=
self
.
use_tqdm_on_load
,
...
...
@@ -1056,6 +994,7 @@ class EngineArgs:
def
create_engine_config
(
self
,
usage_context
:
Optional
[
UsageContext
]
=
None
,
headless
:
bool
=
False
,
)
->
VllmConfig
:
"""
Create the VllmConfig.
...
...
@@ -1070,7 +1009,6 @@ class EngineArgs:
If VLLM_USE_V1 is specified by the user but the VllmConfig
is incompatible, we raise an error.
"""
from
vllm.platforms
import
current_platform
current_platform
.
pre_register_and_update
()
device_config
=
DeviceConfig
(
...
...
@@ -1097,9 +1035,16 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine.
if
use_v1
:
self
.
_set_default_args_v1
(
usage_context
,
model_config
)
# Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
if
current_platform
.
is_cpu
(
)
and
current_platform
.
get_cpu_architecture
()
in
(
CpuArchEnum
.
POWERPC
,
CpuArchEnum
.
ARM
):
logger
.
info
(
"Chunked prefill is not supported for ARM and POWER CPUs; "
"disabling it for V1 backend."
)
self
.
enable_chunked_prefill
=
False
else
:
self
.
_set_default_args_v0
(
model_config
)
assert
self
.
enable_chunked_prefill
is
not
None
if
envs
.
VLLM_ATTENTION_BACKEND
in
[
STR_DUAL_CHUNK_FLASH_ATTN_VAL
]:
...
...
@@ -1138,15 +1083,41 @@ class EngineArgs:
# but we should not do this here.
placement_group
=
ray
.
util
.
get_current_placement_group
()
assert
not
headless
or
not
self
.
data_parallel_hybrid_lb
,
(
"data_parallel_hybrid_lb is not applicable in "
"headless mode"
)
data_parallel_external_lb
=
self
.
data_parallel_rank
is
not
None
# Local DP rank = 1, use pure-external LB.
if
data_parallel_external_lb
:
assert
self
.
data_parallel_size_local
in
(
1
,
None
),
(
"data_parallel_size_local must be 1 when data_parallel_rank "
"is set"
)
data_parallel_size_local
=
1
# Use full external lb if we have local_size of 1.
self
.
data_parallel_hybrid_lb
=
False
elif
self
.
data_parallel_size_local
is
not
None
:
data_parallel_size_local
=
self
.
data_parallel_size_local
if
self
.
data_parallel_start_rank
and
not
headless
:
# Infer hybrid LB mode.
self
.
data_parallel_hybrid_lb
=
True
if
self
.
data_parallel_hybrid_lb
and
data_parallel_size_local
==
1
:
# Use full external lb if we have local_size of 1.
data_parallel_external_lb
=
True
self
.
data_parallel_hybrid_lb
=
False
if
data_parallel_size_local
==
self
.
data_parallel_size
:
# Disable hybrid LB mode if set for a single node
self
.
data_parallel_hybrid_lb
=
False
self
.
data_parallel_rank
=
self
.
data_parallel_start_rank
or
0
else
:
assert
not
self
.
data_parallel_hybrid_lb
,
(
"data_parallel_size_local must be set to use "
"data_parallel_hybrid_lb."
)
# Local DP size defaults to global DP size if not set.
data_parallel_size_local
=
self
.
data_parallel_size
...
...
@@ -1173,6 +1144,26 @@ class EngineArgs:
self
.
data_parallel_rpc_port
is
not
None
)
else
ParallelConfig
.
data_parallel_rpc_port
if
self
.
async_scheduling
:
# Async scheduling does not work with the uniprocess backend.
if
self
.
distributed_executor_backend
is
None
:
self
.
distributed_executor_backend
=
"mp"
logger
.
info
(
"Using mp-based distributed executor backend "
"for async scheduling."
)
if
self
.
distributed_executor_backend
==
"uni"
:
raise
ValueError
(
"Async scheduling is not supported with "
"uni-process backend."
)
if
self
.
pipeline_parallel_size
>
1
:
raise
ValueError
(
"Async scheduling is not supported with "
"pipeline-parallel-size > 1."
)
# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if
self
.
speculative_config
is
not
None
:
raise
ValueError
(
"Currently, speculative decoding is not supported with "
"async scheduling."
)
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
...
...
@@ -1183,6 +1174,7 @@ class EngineArgs:
data_parallel_master_ip
=
data_parallel_address
,
data_parallel_rpc_port
=
data_parallel_rpc_port
,
data_parallel_backend
=
self
.
data_parallel_backend
,
data_parallel_hybrid_lb
=
self
.
data_parallel_hybrid_lb
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
num_redundant_experts
,
...
...
@@ -1216,7 +1208,6 @@ class EngineArgs:
if
self
.
enable_chunked_prefill
and
self
.
pipeline_parallel_size
>
1
:
raise
ValueError
(
"Multi-Step Chunked-Prefill is not supported "
"for pipeline-parallel-size > 1"
)
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cpu
():
logger
.
warning
(
"Multi-Step (--num-scheduler-steps > 1) is "
"currently not supported for CPUs and has been "
...
...
@@ -1254,15 +1245,21 @@ class EngineArgs:
long_prefill_token_threshold
=
self
.
long_prefill_token_threshold
,
disable_hybrid_kv_cache_manager
=
self
.
disable_hybrid_kv_cache_manager
,
async_scheduling
=
self
.
async_scheduling
,
)
if
not
model_config
.
is_multimodal_model
and
self
.
default_mm_loras
:
raise
ValueError
(
"Default modality-specific LoRA(s) were provided for a "
"non multimodal model"
)
lora_config
=
LoRAConfig
(
bias_enabled
=
self
.
enable_lora_bias
,
max_lora_rank
=
self
.
max_lora_rank
,
max_loras
=
self
.
max_loras
,
default_mm_loras
=
self
.
default_mm_loras
,
fully_sharded_loras
=
self
.
fully_sharded_loras
,
lora_extra_vocab_size
=
self
.
lora_extra_vocab_size
,
long_lora_scaling_factors
=
self
.
long_lora_scaling_factors
,
lora_dtype
=
self
.
lora_dtype
,
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
and
self
.
max_cpu_loras
>
0
else
None
,
...
...
@@ -1274,11 +1271,6 @@ class EngineArgs:
load_config
=
self
.
create_load_config
()
prompt_adapter_config
=
PromptAdapterConfig
(
max_prompt_adapters
=
self
.
max_prompt_adapters
,
max_prompt_adapter_token
=
self
.
max_prompt_adapter_token
)
\
if
self
.
enable_prompt_adapter
else
None
decoding_config
=
DecodingConfig
(
backend
=
self
.
guided_decoding_backend
,
disable_fallback
=
self
.
guided_decoding_disable_fallback
,
...
...
@@ -1289,8 +1281,8 @@ class EngineArgs:
)
observability_config
=
ObservabilityConfig
(
show_hidden_metrics_for_version
=
self
.
show_hidden_metrics_for_version
,
show_hidden_metrics_for_version
=
(
self
.
show_hidden_metrics_for_version
)
,
otlp_traces_endpoint
=
self
.
otlp_traces_endpoint
,
collect_detailed_traces
=
self
.
collect_detailed_traces
,
)
...
...
@@ -1306,7 +1298,6 @@ class EngineArgs:
load_config
=
load_config
,
decoding_config
=
decoding_config
,
observability_config
=
observability_config
,
prompt_adapter_config
=
prompt_adapter_config
,
compilation_config
=
self
.
compilation_config
,
kv_transfer_config
=
self
.
kv_transfer_config
,
kv_events_config
=
self
.
kv_events_config
,
...
...
@@ -1366,7 +1357,6 @@ class EngineArgs:
# Skip this check if we are running on a non-GPU platform,
# or if the device capability is not available
# (e.g. in a Ray actor without GPUs).
from
vllm.platforms
import
current_platform
if
(
current_platform
.
is_cuda
()
and
current_platform
.
get_device_capability
()
and
current_platform
.
get_device_capability
().
major
<
8
):
...
...
@@ -1376,34 +1366,16 @@ class EngineArgs:
# No Fp8 KV cache so far.
if
self
.
kv_cache_dtype
!=
"auto"
:
fp8_attention
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
will_use_fa
=
(
current_platform
.
is_cuda
()
and
not
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
)
or
envs
.
VLLM_ATTENTION_BACKEND
==
"FLASH_ATTN_VLLM_V1"
supported
=
False
if
current_platform
.
is_rocm
():
supported
=
True
elif
fp8_attention
and
will_use_fa
:
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
)
supported
=
flash_attn_supports_fp8
()
supported
=
current_platform
.
is_kv_cache_dtype_supported
(
self
.
kv_cache_dtype
)
int8_attention
=
self
.
kv_cache_dtype
.
startswith
(
"int8"
)
if
int8_attention
:
supported
=
True
if
not
supported
:
_raise_or_fallback
(
feature_name
=
"--kv-cache-dtype"
,
recommend_to_remove
=
False
)
return
False
# No Prompt Adapter so far.
if
self
.
enable_prompt_adapter
:
_raise_or_fallback
(
feature_name
=
"--enable-prompt-adapter"
,
recommend_to_remove
=
False
)
return
False
# No text embedding inputs so far.
if
self
.
enable_prompt_embeds
:
_raise_or_fallback
(
feature_name
=
"--enable-prompt-embeds"
,
...
...
@@ -1437,28 +1409,12 @@ class EngineArgs:
return
False
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
is_ngram_enabled
=
False
is_eagle_enabled
=
False
is_medusa_enabled
=
False
if
self
.
speculative_config
is
not
None
:
# This is supported but experimental (handled below).
speculative_method
=
self
.
speculative_config
.
get
(
"method"
)
if
speculative_method
:
if
speculative_method
in
(
"ngram"
,
"[ngram]"
):
is_ngram_enabled
=
True
elif
speculative_method
==
"medusa"
:
is_medusa_enabled
=
True
elif
speculative_method
in
(
"eagle"
,
"eagle3"
,
"deepseek_mtp"
):
is_eagle_enabled
=
True
else
:
speculative_model
=
self
.
speculative_config
.
get
(
"model"
)
if
speculative_model
in
(
"ngram"
,
"[ngram]"
):
is_ngram_enabled
=
True
if
not
(
is_ngram_enabled
or
is_eagle_enabled
or
is_medusa_enabled
):
# Other speculative decoding methods are not supported yet.
_raise_or_fallback
(
feature_name
=
"Speculative Decoding"
,
recommend_to_remove
=
False
)
return
False
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
get
(
"method"
)
==
"draft_model"
):
raise
NotImplementedError
(
"Speculative decoding with draft model is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp."
)
# No XFormers so far.
V1_BACKENDS
=
[
...
...
@@ -1534,7 +1490,6 @@ class EngineArgs:
# Enable chunked prefill by default for long context (> 32K)
# models to avoid OOM errors in initial memory profiling phase.
elif
use_long_context
:
from
vllm.platforms
import
current_platform
is_gpu
=
current_platform
.
is_cuda
()
use_sliding_window
=
(
model_config
.
get_sliding_window
()
is
not
None
)
...
...
@@ -1542,7 +1497,6 @@ class EngineArgs:
if
(
is_gpu
and
not
use_sliding_window
and
not
use_spec_decode
and
not
self
.
enable_lora
and
not
self
.
enable_prompt_adapter
and
model_config
.
runner_type
!=
"pooling"
):
self
.
enable_chunked_prefill
=
True
logger
.
warning
(
...
...
@@ -1636,7 +1590,6 @@ class EngineArgs:
# as the platform that vLLM is running on (e.g. the case of scaling
# vLLM with Ray) and has no GPUs. In this case we use the default
# values for non-H100/H200 GPUs.
from
vllm.platforms
import
current_platform
try
:
device_memory
=
current_platform
.
get_device_total_memory
()
device_name
=
current_platform
.
get_device_name
().
lower
()
...
...
@@ -1647,6 +1600,7 @@ class EngineArgs:
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
# So here we do an extra device name check to prevent such regression.
from
vllm.usage.usage_lib
import
UsageContext
if
device_memory
>=
70
*
GiB_bytes
and
"a100"
not
in
device_name
:
# For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens
=
{
...
...
@@ -1685,13 +1639,14 @@ class EngineArgs:
# cpu specific default values.
if
current_platform
.
is_cpu
():
world_size
=
self
.
pipeline_parallel_size
*
self
.
tensor_parallel_size
default_max_num_batched_tokens
=
{
UsageContext
.
LLM_CLASS
:
4096
,
UsageContext
.
OPENAI_API_SERVER
:
2048
,
UsageContext
.
LLM_CLASS
:
4096
*
world_size
,
UsageContext
.
OPENAI_API_SERVER
:
2048
*
world_size
,
}
default_max_num_seqs
=
{
UsageContext
.
LLM_CLASS
:
128
,
UsageContext
.
OPENAI_API_SERVER
:
32
,
UsageContext
.
LLM_CLASS
:
256
*
world_size
,
UsageContext
.
OPENAI_API_SERVER
:
128
*
world_size
,
}
use_context_value
=
usage_context
.
value
if
usage_context
else
None
...
...
@@ -1739,7 +1694,6 @@ class AsyncEngineArgs(EngineArgs):
parser
.
add_argument
(
'--disable-log-requests'
,
action
=
'store_true'
,
help
=
'Disable logging requests.'
)
from
vllm.platforms
import
current_platform
current_platform
.
pre_register_and_update
(
parser
)
return
parser
...
...
vllm/engine/async_llm_engine.py
View file @
711aa9d5
...
...
@@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import (
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
...
...
@@ -435,9 +434,9 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
"""
Async version of
...
...
@@ -467,7 +466,7 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
prompt
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
tokenization_kwargs
=
tokenization_kwargs
,
)
if
isinstance
(
params
,
SamplingParams
)
and
\
...
...
@@ -489,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine):
params
=
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
priority
,
)
...
...
@@ -859,9 +857,9 @@ class AsyncLLMEngine(EngineClient):
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
PoolingRequestOutput
],
None
]:
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
...
...
@@ -886,9 +884,9 @@ class AsyncLLMEngine(EngineClient):
arrival_time
=
arrival_time
or
time
.
time
(),
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
data_parallel_rank
=
data_parallel_rank
,
tokenization_kwargs
=
tokenization_kwargs
,
)
return
stream
.
generator
()
...
...
@@ -900,7 +898,6 @@ class AsyncLLMEngine(EngineClient):
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
...
@@ -918,8 +915,6 @@ class AsyncLLMEngine(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
...
...
@@ -979,7 +974,6 @@ class AsyncLLMEngine(EngineClient):
sampling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
data_parallel_rank
=
data_parallel_rank
,
):
...
...
@@ -996,6 +990,7 @@ class AsyncLLMEngine(EngineClient):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
"""Generate outputs for a request from a pooling model.
...
...
@@ -1070,6 +1065,7 @@ class AsyncLLMEngine(EngineClient):
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
priority
,
tokenization_kwargs
=
tokenization_kwargs
,
):
yield
LLMEngine
.
validate_output
(
output
,
PoolingRequestOutput
)
except
asyncio
.
CancelledError
:
...
...
vllm/engine/llm_engine.py
View file @
711aa9d5
...
...
@@ -45,7 +45,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from
vllm.outputs
import
(
PoolingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
ParallelSampleSequenceGroup
,
PoolingSequenceGroupOutput
,
Sequence
,
SequenceGroup
,
...
...
@@ -227,7 +226,6 @@ class LLMEngine:
self
.
load_config
=
vllm_config
.
load_config
self
.
decoding_config
=
vllm_config
.
decoding_config
or
DecodingConfig
(
# noqa
)
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
# noqa
self
.
observability_config
=
vllm_config
.
observability_config
or
ObservabilityConfig
(
# noqa
)
...
...
@@ -242,18 +240,18 @@ class LLMEngine:
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
and
self
.
model_config
.
tokenizer_mode
!=
"cpm"
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
None
self
.
detokenizer
=
None
tokenizer_group
=
None
elif
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
self
.
tokenizer
=
CPM9GTokenizer
(
self
.
model_config
.
model
,
trust_remote_code
=
True
)
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
,
self
.
model_config
.
tokenizer_mode
)
tokenizer_group
=
self
.
get_tokenizer_group
()
else
:
self
.
tokenizer
=
None
self
.
detokenizer
=
None
tokenizer_group
=
None
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
...
...
@@ -302,8 +300,6 @@ class LLMEngine:
# Feature flags
"enable_lora"
:
bool
(
self
.
lora_config
),
"enable_prompt_adapter"
:
bool
(
self
.
prompt_adapter_config
),
"enable_prefix_caching"
:
self
.
cache_config
.
enable_prefix_caching
,
"enforce_eager"
:
...
...
@@ -556,9 +552,6 @@ class LLMEngine:
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
scheduler_config
)
if
self
.
prompt_adapter_config
:
self
.
prompt_adapter_config
.
verify_with_model_config
(
self
.
model_config
)
def
_add_processed_request
(
self
,
...
...
@@ -567,7 +560,6 @@ class LLMEngine:
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
Optional
[
SequenceGroup
]:
...
...
@@ -583,7 +575,6 @@ class LLMEngine:
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
return
None
...
...
@@ -601,11 +592,10 @@ class LLMEngine:
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
seq
=
Sequence
(
seq_id
,
decoder_inputs
,
block_size
,
eos_token_id
,
lora_request
,
prompt_adapter_request
)
lora_request
)
encoder_seq
=
(
None
if
encoder_inputs
is
None
else
Sequence
(
seq_id
,
encoder_inputs
,
block_size
,
eos_token_id
,
lora_request
,
prompt_adapter_request
))
seq_id
,
encoder_inputs
,
block_size
,
eos_token_id
,
lora_request
))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if
isinstance
(
params
,
SamplingParams
):
...
...
@@ -616,7 +606,6 @@ class LLMEngine:
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
,
priority
=
priority
)
elif
isinstance
(
params
,
PoolingParams
):
...
...
@@ -626,7 +615,6 @@ class LLMEngine:
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
,
priority
=
priority
)
else
:
...
...
@@ -655,7 +643,6 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
"""Add a request to the engine's request pool.
...
...
@@ -676,7 +663,6 @@ class LLMEngine:
the current monotonic time.
lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
priority: The priority of the request.
Only applicable with priority scheduling.
...
...
@@ -741,7 +727,6 @@ class LLMEngine:
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
self
.
_add_processed_request
(
...
...
@@ -750,7 +735,6 @@ class LLMEngine:
params
=
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
priority
,
)
...
...
@@ -763,7 +747,6 @@ class LLMEngine:
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
...
...
@@ -791,17 +774,15 @@ class LLMEngine:
if
self
.
vllm_config
.
speculative_config
is
not
None
:
draft_size
=
\
self
.
vllm_config
.
speculative_config
.
num_speculative_tokens
+
1
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
sampling_params
=
sampling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
,
priority
=
priority
,
draft_size
=
draft_size
)
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
sampling_params
=
sampling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
encoder_seq
=
encoder_seq
,
priority
=
priority
,
draft_size
=
draft_size
)
return
seq_group
...
...
@@ -812,7 +793,6 @@ class LLMEngine:
pooling_params
:
PoolingParams
,
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
...
...
@@ -820,15 +800,13 @@ class LLMEngine:
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params
=
pooling_params
.
clone
()
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
pooling_params
=
pooling_params
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
,
priority
=
priority
)
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
pooling_params
=
pooling_params
,
encoder_seq
=
encoder_seq
,
priority
=
priority
)
return
seq_group
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
...
...
@@ -1816,13 +1794,6 @@ class LLMEngine:
num_generation_tokens_from_prefill_groups
)
num_tokens_iter
=
(
num_generation_tokens_iter
+
num_prompt_tokens_iter
)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if
model_output
and
isinstance
(
model_output
[
0
],
SamplerOutput
)
and
(
model_output
[
0
].
spec_decode_worker_metrics
is
not
None
):
spec_decode_metrics
=
model_output
[
0
].
spec_decode_worker_metrics
else
:
spec_decode_metrics
=
None
return
Stats
(
now
=
now
,
...
...
@@ -1844,7 +1815,6 @@ class LLMEngine:
num_tokens_iter
=
num_tokens_iter
,
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_per_output_tokens_iter
=
time_per_output_tokens_iter
,
spec_decode_metrics
=
spec_decode_metrics
,
num_preemption_iter
=
num_preemption_iter
,
# Request stats
...
...
@@ -1878,16 +1848,6 @@ class LLMEngine:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_executor
.
pin_lora
(
lora_id
)
def
add_prompt_adapter
(
self
,
prompt_adapter_request
:
PromptAdapterRequest
)
->
bool
:
return
self
.
model_executor
.
add_prompt_adapter
(
prompt_adapter_request
)
def
remove_prompt_adapter
(
self
,
prompt_adapter_id
:
int
)
->
bool
:
return
self
.
model_executor
.
remove_prompt_adapter
(
prompt_adapter_id
)
def
list_prompt_adapters
(
self
)
->
List
[
int
]:
return
self
.
model_executor
.
list_prompt_adapters
()
def
start_profile
(
self
)
->
None
:
self
.
model_executor
.
start_profile
()
...
...
vllm/engine/metrics.py
View file @
711aa9d5
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Type
,
Union
,
cast
...
...
@@ -19,9 +18,6 @@ if ray is not None:
else
:
ray_metrics
=
None
if
TYPE_CHECKING
:
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
logger
=
init_logger
(
__name__
)
prometheus_client
.
disable_created_metrics
()
...
...
@@ -199,30 +195,6 @@ class Metrics:
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Speculative decoding stats
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
gauge_spec_decode_efficiency
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_efficiency"
,
documentation
=
"Speculative decoding system efficiency."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
# --8<-- [end:metrics-definitions]
...
...
@@ -391,9 +363,6 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
...
...
@@ -435,10 +404,6 @@ class LoggingStatLogger(StatLoggerBase):
stats
.
gpu_prefix_cache_hit_rate
*
100
,
stats
.
cpu_prefix_cache_hit_rate
*
100
,
)
if
self
.
spec_decode_metrics
is
not
None
:
log_fn
(
self
.
_format_spec_decode_metrics_str
(
self
.
spec_decode_metrics
))
self
.
_reset
(
stats
,
prompt_throughput
,
generation_throughput
)
...
...
@@ -447,21 +412,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
self
.
last_prompt_throughput
=
prompt_throughput
self
.
last_generation_throughput
=
generation_throughput
def
_format_spec_decode_metrics_str
(
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
return
(
"Speculative metrics: "
f
"Draft acceptance rate:
{
metrics
.
draft_acceptance_rate
:.
3
f
}
, "
f
"System efficiency:
{
metrics
.
system_efficiency
:.
3
f
}
, "
f
"Number of speculative tokens:
{
metrics
.
num_spec_tokens
}
, "
f
"Number of accepted tokens:
{
metrics
.
accepted_tokens
}
, "
f
"Number of draft tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted tokens:
{
metrics
.
emitted_tokens
}
."
)
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
...
...
@@ -579,33 +532,14 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Update spec decode metrics
self
.
maybe_update_spec_decode_metrics
(
stats
)
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
if
self
.
spec_decode_metrics
is
not
None
:
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
self
.
spec_decode_metrics
.
draft_acceptance_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_efficiency
,
self
.
spec_decode_metrics
.
system_efficiency
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_accepted_tokens
,
self
.
spec_decode_metrics
.
accepted_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_draft_tokens
,
self
.
spec_decode_metrics
.
draft_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_emitted_tokens
,
self
.
spec_decode_metrics
.
emitted_tokens
)
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
# Info type metrics are syntactic sugar for a gauge permanently set to 1
...
...
vllm/engine/metrics_types.py
View file @
711aa9d5
...
...
@@ -16,10 +16,9 @@ do this in Python code and lazily import prometheus_client.
import
time
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
from
typing
import
List
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
@
dataclass
...
...
@@ -65,8 +64,6 @@ class Stats:
running_lora_adapters
:
List
[
str
]
max_lora
:
str
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
class
StatLoggerBase
(
ABC
):
"""Base class for StatLogger."""
...
...
@@ -77,7 +74,6 @@ class StatLoggerBase(ABC):
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
spec_decode_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
@
abstractmethod
def
log
(
self
,
stats
:
Stats
)
->
None
:
...
...
@@ -86,9 +82,3 @@ class StatLoggerBase(ABC):
@
abstractmethod
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
def
maybe_update_spec_decode_metrics
(
self
,
stats
:
Stats
):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if
stats
.
spec_decode_metrics
is
not
None
:
self
.
spec_decode_metrics
=
stats
.
spec_decode_metrics
vllm/engine/multiprocessing/__init__.py
View file @
711aa9d5
...
...
@@ -10,7 +10,6 @@ from vllm import PoolingParams
from
vllm.inputs
import
PromptType
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
Device
...
...
@@ -33,7 +32,6 @@ class RPCProcessRequest:
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
priority
:
int
=
0
def
__init__
(
...
...
@@ -43,7 +41,6 @@ class RPCProcessRequest:
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
super
().
__init__
()
...
...
@@ -53,7 +50,6 @@ class RPCProcessRequest:
self
.
request_id
=
request_id
self
.
lora_request
=
lora_request
self
.
trace_headers
=
trace_headers
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
priority
=
priority
...
...
vllm/engine/multiprocessing/client.py
View file @
711aa9d5
...
...
@@ -45,7 +45,6 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
...
...
@@ -453,7 +452,6 @@ class MQLLMEngineClient(EngineClient):
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
...
...
@@ -470,8 +468,6 @@ class MQLLMEngineClient(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
...
...
@@ -479,8 +475,7 @@ class MQLLMEngineClient(EngineClient):
return
cast
(
AsyncGenerator
[
RequestOutput
,
None
],
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
))
lora_request
,
trace_headers
,
priority
))
def
encode
(
self
,
...
...
@@ -526,7 +521,6 @@ class MQLLMEngineClient(EngineClient):
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
PoolingRequestOutput
,
None
]]:
...
...
@@ -580,7 +574,6 @@ class MQLLMEngineClient(EngineClient):
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
))
...
...
vllm/engine/multiprocessing/engine.py
View file @
711aa9d5
...
...
@@ -322,14 +322,12 @@ class MQLLMEngine:
self
.
_send_outputs
(
rpc_err
)
try
:
self
.
engine
.
add_request
(
request_id
=
request_id
,
prompt
=
request
.
prompt
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
prompt_adapter_request
=
request
.
prompt_adapter_request
,
priority
=
request
.
priority
)
self
.
engine
.
add_request
(
request_id
=
request_id
,
prompt
=
request
.
prompt
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
priority
=
request
.
priority
)
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request
.
request_id
)
...
...
vllm/engine/output_processor/multi_step.py
View file @
711aa9d5
...
...
@@ -104,11 +104,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
seqs
=
sequence_group
.
get_seqs
(
status
=
SequenceStatus
.
FINISHED_ABORTED
)
for
output
in
outputs
:
if
output
.
samples
[
0
].
output_token
!=
VLLM_INVALID_TOKEN_ID
:
sequence_group
.
metrics
.
spec_token_acceptance_counts
[
output
.
step_index
]
+=
1
assert
seqs
,
"Expected RUNNING or FINISHED_ABORTED sequences"
assert
len
(
seqs
)
==
1
,
(
"Beam search not supported in multi-step decoding."
)
...
...
vllm/engine/protocol.py
View file @
711aa9d5
...
...
@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Device
,
collect_from_async_generator
,
random_uuid
...
...
@@ -55,7 +54,6 @@ class EngineClient(ABC):
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request."""
...
...
@@ -324,3 +322,9 @@ class EngineClient(ABC):
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
"""Load a new LoRA adapter into the engine for future requests."""
...
async
def
scale_elastic_ep
(
self
,
new_data_parallel_size
:
int
,
drain_timeout
:
int
=
300
)
->
None
:
"""Scale the engine"""
raise
NotImplementedError
vllm/entrypoints/chat_utils.py
View file @
711aa9d5
...
...
@@ -4,7 +4,7 @@
import
asyncio
import
json
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
,
deque
from
collections
import
Counter
,
defaultdict
,
deque
from
collections.abc
import
Awaitable
,
Iterable
from
functools
import
cached_property
,
lru_cache
,
partial
from
pathlib
import
Path
...
...
@@ -28,6 +28,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam
)
from
openai.types.chat.chat_completion_content_part_input_audio_param
import
(
InputAudio
)
from
openai.types.responses
import
ResponseInputImageParam
from
PIL
import
Image
from
pydantic
import
BaseModel
,
ConfigDict
,
TypeAdapter
# yapf: enable
...
...
@@ -38,7 +39,6 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_cls
from
vllm.model_executor.models
import
SupportsMultiModal
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
from
vllm.multimodal.utils
import
MediaConnector
...
...
@@ -52,6 +52,12 @@ from vllm.utils import deprecate_kwargs, random_uuid
logger
=
init_logger
(
__name__
)
MODALITY_PLACEHOLDERS_MAP
=
{
"image"
:
"<##IMAGE##>"
,
"audio"
:
"<##AUDIO##>"
,
"video"
:
"<##VIDEO##>"
,
}
class
AudioURL
(
TypedDict
,
total
=
False
):
url
:
Required
[
str
]
...
...
@@ -145,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
video_url
:
Required
[
str
]
class
CustomThinkCompletionContentParam
(
TypedDict
,
total
=
False
):
"""A Think Completion Content Param that accepts a plain text and a boolean.
Example:
{
"thinking": "I am thinking about the answer",
"closed": True,
"type": "thinking"
}
"""
thinking
:
Required
[
str
]
"""The thinking content."""
closed
:
bool
"""Whether the thinking is closed."""
type
:
Required
[
Literal
[
"thinking"
]]
"""The thinking type."""
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
ChatCompletionContentPartInputAudioParam
,
...
...
@@ -153,7 +180,8 @@ ChatCompletionContentPartParam: TypeAlias = Union[
CustomChatCompletionContentSimpleImageParam
,
ChatCompletionContentPartImageEmbedsParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleVideoParam
,
str
]
CustomChatCompletionContentSimpleVideoParam
,
str
,
CustomThinkCompletionContentParam
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
...
...
@@ -354,6 +382,7 @@ def resolve_mistral_chat_template(
"so it will be ignored."
)
return
None
@
deprecate_kwargs
(
"trust_remote_code"
,
additional_message
=
"Please use `model_config.trust_remote_code` instead."
,
...
...
@@ -517,6 +546,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@
cached_property
def
model_cls
(
self
):
from
vllm.model_executor.model_loader
import
get_model_cls
return
get_model_cls
(
self
.
model_config
)
@
property
...
...
@@ -633,15 +663,22 @@ class BaseMultiModalContentParser(ABC):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# multimodal placeholder_string : count
self
.
_placeholder_counts
:
dict
[
str
,
int
]
=
defaultdict
(
lambda
:
0
)
def
_add_placeholder
(
self
,
placeholder
:
Optional
[
str
]):
# stores model placehodlers list with corresponding
# general MM placeholder:
# {
# "<##IMAGE##>": ["<image>", "<image>", "<image>"],
# "<##AUDIO##>": ["<audio>", "<audio>"]
# }
self
.
_placeholder_storage
:
dict
[
str
,
list
]
=
defaultdict
(
list
)
def
_add_placeholder
(
self
,
modality
:
ModalityStr
,
placeholder
:
Optional
[
str
]):
mod_placeholder
=
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
placeholder
:
self
.
_placeholder_
counts
[
placeholder
]
+=
1
self
.
_placeholder_
storage
[
mod_placeholder
].
append
(
placeholder
)
def
mm_placeholder_
counts
(
self
)
->
dict
[
str
,
in
t
]:
return
dict
(
self
.
_placeholder_
counts
)
def
mm_placeholder_
storage
(
self
)
->
dict
[
str
,
lis
t
]:
return
dict
(
self
.
_placeholder_
storage
)
@
abstractmethod
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
...
...
@@ -685,7 +722,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image
=
self
.
_connector
.
fetch_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_embeds
(
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
...
...
@@ -700,17 +737,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
embedding
=
self
.
_connector
.
fetch_image_embedding
(
image_embeds
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
embedding
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_pil
(
self
,
image_pil
:
Image
.
Image
)
->
None
:
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_pil
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
self
.
_connector
.
fetch_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"audio"
,
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
...
...
@@ -723,7 +760,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
video
=
self
.
_connector
.
fetch_video
(
video_url
=
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"video"
,
placeholder
)
class
AsyncMultiModalContentParser
(
BaseMultiModalContentParser
):
...
...
@@ -741,7 +778,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_coro
=
self
.
_connector
.
fetch_image_async
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_embeds
(
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
...
...
@@ -760,20 +797,20 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future
.
set_result
(
embedding
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
future
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_pil
(
self
,
image_pil
:
Image
.
Image
)
->
None
:
future
:
asyncio
.
Future
[
Image
.
Image
]
=
asyncio
.
Future
()
future
.
set_result
(
image_pil
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
future
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"audio"
,
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
...
...
@@ -786,7 +823,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
video
=
self
.
_connector
.
fetch_video_async
(
video_url
=
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
"video"
,
placeholder
)
def
validate_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]]):
...
...
@@ -856,12 +893,40 @@ def load_chat_template(
return
_cached_load_chat_template
(
chat_template
,
is_literal
=
is_literal
)
def
_get_interleaved_text_prompt
(
placeholder_storage
:
dict
[
str
,
list
],
texts
:
list
[
str
])
->
str
:
for
idx
,
elem
in
enumerate
(
texts
):
if
elem
in
placeholder_storage
:
texts
[
idx
]
=
placeholder_storage
[
elem
].
pop
(
0
)
return
"
\n
"
.
join
(
texts
)
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def
_get_full_multimodal_text_prompt
(
placeholder_counts
:
dict
[
str
,
int
],
text_prompt
:
str
)
->
str
:
def
_get_full_multimodal_text_prompt
(
placeholder_storage
:
dict
[
str
,
list
],
texts
:
list
[
str
],
interleave_strings
:
bool
)
->
str
:
"""Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it looks like
# {
# "<|image|>": 2,
# "<|audio|>": 1
# }
placeholder_counts
=
Counter
(
[
v
for
elem
in
placeholder_storage
.
values
()
for
v
in
elem
]
)
if
interleave_strings
:
text_prompt
=
_get_interleaved_text_prompt
(
placeholder_storage
,
texts
)
else
:
text_prompt
=
"
\n
"
.
join
(
texts
)
# Pass interleaved text further in case the user used image placeholders
# himself, but forgot to disable the 'interleave_strings' flag
# Look through the text prompt to check for missing placeholders
missing_placeholders
:
list
[
str
]
=
[]
for
placeholder
in
placeholder_counts
:
...
...
@@ -870,6 +935,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
placeholder_counts
[
placeholder
]
-=
text_prompt
.
count
(
placeholder
)
if
placeholder_counts
[
placeholder
]
<
0
:
logger
.
error
(
"Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) "
"when manually placing image placeholders."
,
interleave_strings
)
logger
.
debug
(
"Input prompt: %s"
,
text_prompt
)
raise
ValueError
(
f
"Found more '
{
placeholder
}
' placeholders in input prompt than "
"actual multimodal data items."
)
...
...
@@ -877,8 +949,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
missing_placeholders
.
extend
([
placeholder
]
*
placeholder_counts
[
placeholder
])
# NOTE:
For now
we always add missing placeholders
at the front of
# the
p
ro
mpt. This may change to be customizable in the future.
# NOTE:
Default behaviour:
we always add missing placeholders
#
at
the
f
ro
nt of the prompt, if interleave_strings=False
return
"
\n
"
.
join
(
missing_placeholders
+
[
text_prompt
])
...
...
@@ -888,11 +960,14 @@ _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser
=
partial
(
cast
,
ChatCompletionContentPartInputAudioParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_PILImageParser
=
partial
(
cast
,
CustomChatCompletionContentPILImageParam
)
_ThinkParser
=
partial
(
cast
,
CustomThinkCompletionContentParam
)
# Need to validate url objects
_ImageParser
=
TypeAdapter
(
ChatCompletionContentPartImageParam
).
validate_python
_AudioParser
=
TypeAdapter
(
ChatCompletionContentPartAudioParam
).
validate_python
_VideoParser
=
TypeAdapter
(
ChatCompletionContentPartVideoParam
).
validate_python
_ResponsesInputImageParser
=
TypeAdapter
(
ResponseInputImageParam
).
validate_python
_ContentPart
:
TypeAlias
=
Union
[
str
,
dict
[
str
,
str
],
InputAudio
,
PILImage
]
# Define a mapping from part types to their corresponding parsing functions.
...
...
@@ -902,6 +977,12 @@ MM_PARSER_MAP: dict[
]
=
{
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
"thinking"
:
lambda
part
:
_ThinkParser
(
part
).
get
(
"thinking"
,
None
),
"input_text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
"input_image"
:
lambda
part
:
_ResponsesInputImageParser
(
part
).
get
(
"image_url"
,
None
),
"image_url"
:
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
None
),
"image_embeds"
:
...
...
@@ -986,6 +1067,7 @@ def _parse_chat_message_content_parts(
mm_tracker
:
BaseMultiModalItemTracker
,
*
,
wrap_dicts
:
bool
,
interleave_strings
:
bool
,
)
->
list
[
ConversationMessage
]:
content
=
list
[
_ContentPart
]()
...
...
@@ -996,6 +1078,7 @@ def _parse_chat_message_content_parts(
part
,
mm_parser
,
wrap_dicts
=
wrap_dicts
,
interleave_strings
=
interleave_strings
)
if
parse_res
:
content
.
append
(
parse_res
)
...
...
@@ -1005,11 +1088,14 @@ def _parse_chat_message_content_parts(
return
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
# type: ignore
texts
=
cast
(
list
[
str
],
content
)
text_prompt
=
"
\n
"
.
join
(
texts
)
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
mm_placeholder_counts
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
text_prompt
)
mm_placeholder_storage
=
mm_parser
.
mm_placeholder_storage
()
if
mm_placeholder_storage
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_storage
,
texts
,
interleave_strings
)
else
:
text_prompt
=
"
\n
"
.
join
(
texts
)
return
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
...
...
@@ -1018,6 +1104,7 @@ def _parse_chat_message_content_part(
mm_parser
:
BaseMultiModalContentParser
,
*
,
wrap_dicts
:
bool
,
interleave_strings
:
bool
,
)
->
Optional
[
_ContentPart
]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
...
...
@@ -1028,10 +1115,8 @@ def _parse_chat_message_content_part(
"""
if
isinstance
(
part
,
str
):
# Handle plain text parts
return
part
# Handle structured dictionary parts
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is None, log a warning and skip
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
content
is
None
:
...
...
@@ -1040,41 +1125,44 @@ def _parse_chat_message_content_part(
"with empty / unparsable content."
,
part
,
part_type
)
return
None
if
part_type
in
(
"text"
,
"
refusal
"
):
if
part_type
in
(
"text"
,
"
input_text"
,
"refusal"
,
"thinking
"
):
str_content
=
cast
(
str
,
content
)
if
wrap_dicts
:
return
{
'type'
:
'text'
,
'text'
:
str_content
}
else
:
return
str_content
modality
=
None
if
part_type
==
"image_pil"
:
image_content
=
cast
(
Image
.
Image
,
content
)
mm_parser
.
parse_image_pil
(
image_content
)
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
if
part_type
==
"image_url"
:
modality
=
"image"
el
if
part_type
in
(
"image_url"
,
"input_image"
)
:
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_image
(
str_content
)
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
if
part_type
==
"image_embeds"
:
modality
=
"image"
el
if
part_type
==
"image_embeds"
:
content
=
cast
(
Union
[
str
,
dict
[
str
,
str
]],
content
)
mm_parser
.
parse_image_embeds
(
content
)
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
if
part_type
==
"audio_url"
:
modality
=
"image"
el
if
part_type
==
"audio_url"
:
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_audio
(
str_content
)
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
if
part_type
==
"input_audio"
:
modality
=
"audio"
elif
part_type
==
"input_audio"
:
dict_content
=
cast
(
InputAudio
,
content
)
mm_parser
.
parse_input_audio
(
dict_content
)
return
{
'type'
:
'audio'
}
if
wrap_dicts
else
None
if
part_type
==
"video_url"
:
modality
=
"audio"
elif
part_type
==
"video_url"
:
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_video
(
str_content
)
return
{
'type'
:
'video'
}
if
wrap_dicts
else
None
modality
=
"video"
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
return
{
'type'
:
modality
}
if
wrap_dicts
else
(
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
interleave_strings
else
None
)
# No need to validate using Pydantic again
...
...
@@ -1086,6 +1174,7 @@ def _parse_chat_message_content(
message
:
ChatCompletionMessageParam
,
mm_tracker
:
BaseMultiModalItemTracker
,
content_format
:
_ChatTemplateContentFormat
,
interleave_strings
:
bool
,
)
->
list
[
ConversationMessage
]:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
...
...
@@ -1101,6 +1190,7 @@ def _parse_chat_message_content(
content
,
# type: ignore
mm_tracker
,
wrap_dicts
=
(
content_format
==
"openai"
),
interleave_strings
=
interleave_strings
,
)
for
result_msg
in
result
:
...
...
@@ -1153,6 +1243,11 @@ def parse_chat_messages(
msg
,
mm_tracker
,
content_format
,
interleave_strings
=
(
content_format
==
"string"
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
.
interleave_mm_strings
)
)
conversation
.
extend
(
sub_messages
)
...
...
@@ -1176,6 +1271,11 @@ def parse_chat_messages_futures(
msg
,
mm_tracker
,
content_format
,
interleave_strings
=
(
content_format
==
"string"
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
.
interleave_mm_strings
)
)
conversation
.
extend
(
sub_messages
)
...
...
vllm/entrypoints/cli/main.py
View file @
711aa9d5
...
...
@@ -7,17 +7,6 @@ to avoid certain eager import breakage.'''
from
__future__
import
annotations
import
importlib.metadata
import
signal
import
sys
def
register_signal_handlers
():
def
signal_handler
(
sig
,
frame
):
sys
.
exit
(
0
)
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
signal
.
signal
(
signal
.
SIGTSTP
,
signal_handler
)
def
main
():
...
...
vllm/entrypoints/cli/openai.py
View file @
711aa9d5
...
...
@@ -55,7 +55,7 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
try
:
input_message
=
input
(
"> "
)
except
EOFError
:
re
turn
b
re
ak
conversation
.
append
({
"role"
:
"user"
,
"content"
:
input_message
})
chat_completion
=
client
.
chat
.
completions
.
create
(
model
=
model_name
,
...
...
@@ -118,7 +118,7 @@ class ChatCommand(CLISubcommand):
try
:
input_message
=
input
(
"> "
)
except
EOFError
:
re
turn
b
re
ak
conversation
.
append
({
"role"
:
"user"
,
"content"
:
input_message
})
chat_completion
=
client
.
chat
.
completions
.
create
(
...
...
@@ -170,7 +170,10 @@ class CompleteCommand(CLISubcommand):
print
(
"Please enter prompt to complete:"
)
while
True
:
input_prompt
=
input
(
"> "
)
try
:
input_prompt
=
input
(
"> "
)
except
EOFError
:
break
completion
=
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
input_prompt
)
output
=
completion
.
choices
[
0
].
text
...
...
vllm/entrypoints/cli/serve.py
View file @
711aa9d5
...
...
@@ -45,9 +45,6 @@ class ServeSubcommand(CLISubcommand):
if
args
.
headless
or
args
.
api_server_count
<
1
:
run_headless
(
args
)
else
:
if
args
.
data_parallel_start_rank
:
raise
ValueError
(
"data_parallel_start_rank is only "
"applicable in headless mode"
)
if
args
.
api_server_count
>
1
:
run_multi_api_server
(
args
)
else
:
...
...
@@ -65,36 +62,6 @@ class ServeSubcommand(CLISubcommand):
help
=
"Start the vLLM OpenAI Compatible API server."
,
description
=
"Start the vLLM OpenAI Compatible API server."
,
usage
=
"vllm serve [model_tag] [options]"
)
serve_parser
.
add_argument
(
"model_tag"
,
type
=
str
,
nargs
=
'?'
,
help
=
"The model tag to serve "
"(optional if specified in config)"
)
serve_parser
.
add_argument
(
"--headless"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Run in headless mode. See multi-node data parallel "
"documentation for more details."
)
serve_parser
.
add_argument
(
'--data-parallel-start-rank'
,
'-dpr'
,
type
=
int
,
default
=
0
,
help
=
'Starting data parallel rank for secondary nodes.'
)
serve_parser
.
add_argument
(
'--api-server-count'
,
'-asc'
,
type
=
int
,
default
=
1
,
help
=
'How many API server processes to run.'
)
serve_parser
.
add_argument
(
"--config"
,
type
=
str
,
default
=
''
,
required
=
False
,
help
=
"Read CLI options from a config file. "
"Must be a YAML with the following options: "
"https://docs.vllm.ai/en/latest/configuration/serve_args.html"
)
serve_parser
=
make_arg_parser
(
serve_parser
)
show_filtered_argument_or_group_from_help
(
serve_parser
,
[
"serve"
])
...
...
@@ -114,13 +81,14 @@ def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args
=
vllm
.
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
,
headless
=
True
)
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Headless mode is only supported for V1"
)
if
engine_args
.
data_parallel_
rank
is
not
None
:
raise
ValueError
(
"data_parallel_
rank
is not applicable in "
if
engine_args
.
data_parallel_
hybrid_lb
:
raise
ValueError
(
"data_parallel_
hybrid_lb
is not applicable in "
"headless mode"
)
parallel_config
=
vllm_config
.
parallel_config
...
...
@@ -150,7 +118,7 @@ def run_headless(args: argparse.Namespace):
engine_manager
=
CoreEngineProcManager
(
target_fn
=
EngineCoreProc
.
run_engine_core
,
local_engine_count
=
local_engine_count
,
start_index
=
args
.
data_parallel_
start_
rank
,
start_index
=
vllm_config
.
parallel_config
.
data_parallel_rank
,
local_start_index
=
0
,
vllm_config
=
vllm_config
,
local_client
=
False
,
...
...
@@ -197,6 +165,11 @@ def run_multi_api_server(args: argparse.Namespace):
" api_server_count > 1"
)
model_config
.
disable_mm_preprocessor_cache
=
True
if
vllm_config
.
parallel_config
.
data_parallel_hybrid_lb
:
raise
NotImplementedError
(
"Hybrid load balancing with --api-server-count > 0"
"is not yet supported."
)
executor_class
=
Executor
.
get_class
(
vllm_config
)
log_stats
=
not
engine_args
.
disable_log_stats
...
...
Prev
1
…
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment