Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dbe8c07
"vllm/tool_parsers/step3_tool_parser.py" did not exist on "34a984274eae2f8fb9d1d6413abd08d7fcde741c"
Unverified
Commit
2dbe8c07
authored
May 30, 2025
by
Nick Hill
Committed by
GitHub
May 30, 2025
Browse files
[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
parent
84ec470f
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
818 additions
and
353 deletions
+818
-353
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+166
-87
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+244
-212
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+45
-34
vllm/v1/metrics/prometheus.py
vllm/v1/metrics/prometheus.py
+77
-0
vllm/v1/request.py
vllm/v1/request.py
+3
-2
vllm/v1/utils.py
vllm/v1/utils.py
+283
-18
No files found.
vllm/v1/engine/core.py
View file @
2dbe8c07
...
...
@@ -7,6 +7,7 @@ import threading
import
time
from
collections
import
deque
from
concurrent.futures
import
Future
from
contextlib
import
ExitStack
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
...
...
@@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
make_zmq_socket
,
resolve_obj_by_qualname
,
zmq_socket_ctx
from
vllm.utils
import
make_zmq_socket
,
resolve_obj_by_qualname
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
...
...
@@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.utils
import
EngineHandshakeMetadata
,
EngineZmqAddresses
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -211,7 +214,7 @@ class EngineCore:
# Re-raise exception
raise
err
def
step
(
self
)
->
tuple
[
EngineCoreOutputs
,
bool
]:
def
step
(
self
)
->
tuple
[
dict
[
int
,
EngineCoreOutputs
]
,
bool
]:
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
...
...
@@ -221,10 +224,7 @@ class EngineCore:
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if
not
self
.
scheduler
.
has_requests
():
return
EngineCoreOutputs
(
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
(),
),
False
return
{},
False
scheduler_output
=
self
.
scheduler
.
schedule
()
model_output
=
self
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
...
...
@@ -234,7 +234,7 @@ class EngineCore:
scheduler_output
.
total_num_scheduled_tokens
>
0
)
def
step_with_batch_queue
(
self
)
->
tuple
[
Optional
[
EngineCoreOutputs
],
bool
]:
self
)
->
tuple
[
Optional
[
dict
[
int
,
EngineCoreOutputs
]
]
,
bool
]:
"""Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned.
...
...
@@ -276,8 +276,8 @@ class EngineCore:
# Blocking until the first result is available.
model_output
=
future
.
result
()
self
.
batch_queue
.
task_done
()
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
engine_core_outputs
=
(
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
)
return
engine_core_outputs
,
scheduled_batch
...
...
@@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
self
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
engine_index
:
int
=
0
,
...
...
@@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):
# Create input socket.
input_ctx
=
zmq
.
Context
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
input_socket
=
make_zmq_socket
(
input_ctx
,
input_address
,
zmq
.
DEALER
,
identity
=
identity
,
bind
=
False
)
try
:
with
make_zmq_socket
(
input_ctx
,
handshake_address
,
zmq
.
DEALER
,
identity
=
identity
,
linger
=
5000
,
bind
=
False
)
as
handshake_socket
:
# Register engine with front-end.
output_address
=
self
.
startup_handshake
(
input_socket
,
on_head_node
,
vllm_config
.
parallel_config
)
addresses
=
self
.
startup_handshake
(
handshake_socket
,
on_head_node
,
vllm_config
.
parallel_config
)
self
.
client_count
=
len
(
addresses
.
outputs
)
# Update config which may have changed from the handshake.
vllm_config
.
__post_init__
()
# Set up data parallel environment.
self
.
has_coordinator
=
addresses
.
coordinator_output
is
not
None
self
.
_init_data_parallel
(
vllm_config
)
# Initialize engine core and model.
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
executor_fail_callback
)
self
.
engine_index
=
engine_index
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
engines_running
=
False
self
.
last_counts
=
(
0
,
0
)
# Send ready message.
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
input
_socket
.
send
(
handshake
_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"READY"
,
"local"
:
on_head_node
,
"num_gpu_blocks"
:
num_gpu_blocks
,
}))
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
bytes
]]()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_socket
,
),
daemon
=
True
).
start
()
input_socket
=
None
self
.
output_thread
=
threading
.
Thread
(
target
=
self
.
process_output_socket
,
args
=
(
output_address
,
engine_index
),
daemon
=
True
)
self
.
output_thread
.
start
()
finally
:
if
input_socket
is
not
None
:
input_socket
.
close
(
linger
=
0
)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
tuple
[
int
,
EngineCoreOutputs
],
bytes
]]()
threading
.
Thread
(
target
=
self
.
process_input_sockets
,
args
=
(
addresses
.
inputs
,
addresses
.
coordinator_input
,
identity
),
daemon
=
True
).
start
()
self
.
output_thread
=
threading
.
Thread
(
target
=
self
.
process_output_sockets
,
args
=
(
addresses
.
outputs
,
addresses
.
coordinator_output
,
engine_index
),
daemon
=
True
)
self
.
output_thread
.
start
()
@
staticmethod
def
startup_handshake
(
input_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
parallel_config
:
ParallelConfig
)
->
str
:
def
startup_handshake
(
handshake_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
parallel_config
:
ParallelConfig
)
->
EngineZmqAddresses
:
# Send registration message.
input
_socket
.
send
(
handshake
_socket
.
send
(
msgspec
.
msgpack
.
encode
({
"status"
:
"HELLO"
,
"local"
:
on_head_node
,
...
...
@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):
# Receive initialization message.
logger
.
info
(
"Waiting for init message from front-end."
)
if
not
input
_socket
.
poll
(
timeout
=
HANDSHAKE_TIMEOUT_MINS
*
60
*
1
000
):
if
not
handshake
_socket
.
poll
(
timeout
=
HANDSHAKE_TIMEOUT_MINS
*
60
_
000
):
raise
RuntimeError
(
"Did not receive response from front-end "
f
"process within
{
HANDSHAKE_TIMEOUT_MINS
}
"
f
"minutes"
)
init_bytes
=
input_socket
.
recv
()
init_message
=
msgspec
.
msgpack
.
decode
(
init_bytes
)
init_bytes
=
handshake_socket
.
recv
()
init_message
:
EngineHandshakeMetadata
=
msgspec
.
msgpack
.
decode
(
init_bytes
,
type
=
EngineHandshakeMetadata
)
logger
.
debug
(
"Received init message: %s"
,
init_message
)
output_socket_address
=
init_message
[
"output_socket_address"
]
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config
=
init_message
[
"parallel_config"
]
received_parallel_config
=
init_message
.
parallel_config
for
key
,
value
in
received_parallel_config
.
items
():
setattr
(
parallel_config
,
key
,
value
)
return
output_socket_
address
return
init_message
.
address
es
@
staticmethod
def
run_engine_core
(
*
args
,
...
...
@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed."""
waited
=
False
while
not
self
.
engines_running
and
not
(
self
.
scheduler
.
has_requests
()
)
:
while
not
self
.
engines_running
and
not
self
.
scheduler
.
has_requests
():
if
logger
.
isEnabledFor
(
DEBUG
)
and
self
.
input_queue
.
empty
():
logger
.
debug
(
"EngineCore waiting for work."
)
waited
=
True
...
...
@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
# Step the engine core.
outputs
,
model_executed
=
self
.
step_fn
()
# Put EngineCoreOutputs into the output queue.
i
f
output
s
i
s
not
None
:
self
.
output_queue
.
put_nowait
(
output
s
)
f
or
output
i
n
(
outputs
.
items
()
if
outputs
else
())
:
self
.
output_queue
.
put_nowait
(
output
)
return
model_executed
...
...
@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
self
.
abort_requests
(
request
)
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
call_id
,
method_name
,
args
=
request
client_idx
,
call_id
,
method_name
,
args
=
request
output
=
UtilityOutput
(
call_id
)
try
:
method
=
getattr
(
self
,
method_name
)
...
...
@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
output
.
failure_message
=
(
f
"Call to
{
method_name
}
method"
f
" failed:
{
str
(
e
)
}
"
)
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
utility_output
=
output
))
(
client_idx
,
EngineCoreOutputs
(
utility_output
=
output
))
)
elif
request_type
==
EngineCoreRequestType
.
EXECUTOR_FAILED
:
raise
RuntimeError
(
"Executor failed."
)
else
:
...
...
@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
logger
.
fatal
(
"vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue."
)
def
process_input_socket
(
self
,
input_socket
:
zmq
.
Socket
):
def
process_input_sockets
(
self
,
input_addresses
:
list
[
str
],
coord_input_address
:
Optional
[
str
],
identity
:
bytes
):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
generic_decoder
=
MsgpackDecoder
()
while
True
:
# (RequestType, RequestData)
type_frame
,
*
data_frames
=
input_socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frames
)
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
with
ExitStack
()
as
stack
,
zmq
.
Context
()
as
ctx
:
input_sockets
=
[
stack
.
enter_context
(
make_zmq_socket
(
ctx
,
input_address
,
zmq
.
DEALER
,
identity
=
identity
,
bind
=
False
))
for
input_address
in
input_addresses
]
if
coord_input_address
is
None
:
coord_socket
=
None
else
:
coord_socket
=
stack
.
enter_context
(
make_zmq_socket
(
ctx
,
coord_input_address
,
zmq
.
XSUB
,
identity
=
identity
,
bind
=
False
))
# Send subscription message to coordinator.
coord_socket
.
send
(
b
'
\x01
'
)
# Register sockets with poller.
poller
=
zmq
.
Poller
()
for
input_socket
in
input_sockets
:
# Send initial message to each input socket - this is required
# before the front-end ROUTER socket can send input messages
# back to us.
input_socket
.
send
(
b
''
)
poller
.
register
(
input_socket
,
zmq
.
POLLIN
)
if
coord_socket
is
not
None
:
poller
.
register
(
coord_socket
,
zmq
.
POLLIN
)
def
process_output_socket
(
self
,
output_path
:
str
,
engine_index
:
int
):
while
True
:
for
input_socket
,
_
in
poller
.
poll
():
# (RequestType, RequestData)
type_frame
,
*
data_frames
=
input_socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frames
)
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
def
process_output_sockets
(
self
,
output_paths
:
list
[
str
],
coord_output_path
:
Optional
[
str
],
engine_index
:
int
):
"""Output socket IO thread."""
# Msgpack serialization encoding.
...
...
@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):
# We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket.
with
zmq_socket_ctx
(
output_path
,
zmq
.
constants
.
PUSH
,
linger
=
4000
)
as
socket
:
with
ExitStack
()
as
stack
,
zmq
.
Context
()
as
ctx
:
sockets
=
[
stack
.
enter_context
(
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
PUSH
,
linger
=
4000
))
for
output_path
in
output_paths
]
coord_socket
=
stack
.
enter_context
(
make_zmq_socket
(
ctx
,
coord_output_path
,
zmq
.
PUSH
,
bind
=
False
,
linger
=
4000
))
if
coord_output_path
is
not
None
else
None
max_reuse_bufs
=
len
(
sockets
)
+
1
while
True
:
outputs
=
self
.
output_queue
.
get
()
if
outputs
==
EngineCoreProc
.
ENGINE_CORE_DEAD
:
socket
.
send
(
outputs
,
copy
=
False
)
output
=
self
.
output_queue
.
get
()
if
output
==
EngineCoreProc
.
ENGINE_CORE_DEAD
:
for
socket
in
sockets
:
socket
.
send
(
output
)
break
assert
not
isinstance
(
outputs
,
bytes
)
assert
not
isinstance
(
output
,
bytes
)
client_index
,
outputs
=
output
outputs
.
engine_index
=
engine_index
if
client_index
==
-
1
:
# Don't reuse buffer for coordinator message
# which will be very small.
assert
coord_socket
is
not
None
coord_socket
.
send_multipart
(
encoder
.
encode
(
outputs
))
continue
# Reclaim buffers that zmq is finished with.
while
pending
and
pending
[
-
1
][
0
].
done
:
reuse_buffers
.
append
(
pending
.
pop
()[
2
])
buffer
=
reuse_buffers
.
pop
()
if
reuse_buffers
else
bytearray
()
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
tracker
=
socket
.
send_multipart
(
buffers
,
copy
=
False
,
track
=
True
)
tracker
=
socket
s
[
client_index
]
.
send_multipart
(
buffers
,
copy
=
False
,
track
=
True
)
if
not
tracker
.
done
:
ref
=
outputs
if
len
(
buffers
)
>
1
else
None
pending
.
appendleft
((
tracker
,
ref
,
buffer
))
elif
len
(
reuse_buffers
)
<
2
:
#
Keep at most 2
buffers to reuse.
elif
len
(
reuse_buffers
)
<
max_reuse_bufs
:
#
Limit the number of
buffers to reuse.
reuse_buffers
.
append
(
buffer
)
...
...
@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
self
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
...
...
@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self
.
counter
=
0
self
.
current_wave
=
0
# Initialize the engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
super
().
__init__
(
vllm_config
,
on_head_node
,
input
_address
,
super
().
__init__
(
vllm_config
,
on_head_node
,
handshake
_address
,
executor_class
,
log_stats
,
dp_rank
)
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
...
...
@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):
self
.
dp_rank
=
dp_rank
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
current_wave
=
0
def
shutdown
(
self
):
super
().
shutdown
()
...
...
@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
def
add_request
(
self
,
request
:
EngineCoreRequest
):
if
request
.
current_wave
!=
self
.
current_wave
:
if
self
.
has_coordinator
and
request
.
current_wave
!=
self
.
current_wave
:
if
request
.
current_wave
>
self
.
current_wave
:
self
.
current_wave
=
request
.
current_wave
elif
not
self
.
engines_running
:
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
start_wave
=
self
.
current_wave
))
(
-
1
,
EngineCoreOutputs
(
start_wave
=
self
.
current_wave
))
)
super
().
add_request
(
request
)
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
if
request_type
==
EngineCoreRequestType
.
START_DP_WAVE
:
new_wave
:
int
=
request
if
new_wave
>=
self
.
current_wave
:
new_wave
,
exclude_eng_index
=
request
if
exclude_eng_index
!=
self
.
engine_index
and
(
new_wave
>=
self
.
current_wave
):
self
.
current_wave
=
new_wave
if
not
self
.
engines_running
:
logger
.
debug
(
"EngineCore starting idle loop for wave %d."
,
...
...
@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
else
:
super
().
_handle_client_request
(
request_type
,
request
)
def
_maybe_publish_request_counts
(
self
):
if
not
self
.
has_coordinator
:
return
# Publish our request counts (if they've changed).
counts
=
self
.
scheduler
.
get_request_counts
()
if
counts
!=
self
.
last_counts
:
self
.
last_counts
=
counts
stats
=
SchedulerStats
(
*
counts
)
self
.
output_queue
.
put_nowait
(
(
-
1
,
EngineCoreOutputs
(
scheduler_stats
=
stats
)))
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore for data parallel case."""
...
...
@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):
# 2) Step the engine core.
executed
=
self
.
_process_engine_step
()
self
.
_maybe_publish_request_counts
()
local_unfinished_reqs
=
self
.
scheduler
.
has_unfinished_requests
()
if
not
executed
:
if
not
local_unfinished_reqs
and
not
self
.
engines_running
:
...
...
@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
self
.
current_wave
)
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
))
(
-
1
,
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
)))
self
.
current_wave
+=
1
def
_has_global_unfinished_reqs
(
self
,
local_unfinished
:
bool
)
->
bool
:
...
...
vllm/v1/engine/core_client.py
View file @
2dbe8c07
...
...
@@ -2,6 +2,7 @@
import
asyncio
import
contextlib
import
queue
import
sys
import
uuid
import
weakref
from
abc
import
ABC
,
abstractmethod
...
...
@@ -9,26 +10,28 @@ from collections import deque
from
collections.abc
import
Awaitable
,
Sequence
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
threading
import
Thread
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
msgspec
.msgpack
import
zmq
import
zmq.asyncio
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
(
get_open_port
,
get_open_zmq_inproc_path
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
make_
zmq_socket
)
from
vllm.utils
import
(
get_open_zmq_inproc_path
,
make_zmq_socket
,
zmq_socket
_ctx
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
CoreEngineProcManager
from
vllm.v1.utils
import
(
CoreEngine
,
CoreEngineProcManager
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_engine_startup
)
logger
=
init_logger
(
__name__
)
...
...
@@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
STARTUP_POLL_PERIOD_MS
=
10000
class
EngineCoreClient
(
ABC
):
"""
...
...
@@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient):
def
get_output
(
self
)
->
EngineCoreOutputs
:
outputs
,
_
=
self
.
engine_core
.
step
()
return
outputs
return
outputs
.
get
(
0
)
or
EngineCoreOutputs
()
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
self
.
engine_core
.
add_request
(
request
)
...
...
@@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient):
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
.
local
=
local
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
self
.
state
=
CoreEngineState
.
NEW
self
.
num_reqs_in_flight
=
0
@
dataclass
class
BackgroundResources
:
"""Used as a finalizer for clean shutdown, avoiding
...
...
@@ -291,9 +274,12 @@ class BackgroundResources:
ctx
:
Union
[
zmq
.
Context
]
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
coordinator
:
Optional
[
DPCoordinator
]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
first_req_send_socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
output_queue_task
:
Optional
[
asyncio
.
Task
]
=
None
stats_update_task
:
Optional
[
asyncio
.
Task
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
# Set if any of the engines are dead. Here so that the output
...
...
@@ -306,16 +292,21 @@ class BackgroundResources:
self
.
engine_dead
=
True
if
self
.
local_engine_manager
is
not
None
:
self
.
local_engine_manager
.
close
()
if
self
.
coordinator
is
not
None
:
self
.
coordinator
.
close
()
if
self
.
output_queue_task
is
not
None
:
self
.
output_queue_task
.
cancel
()
if
self
.
stats_update_task
is
not
None
:
self
.
stats_update_task
.
cancel
()
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
if
self
.
output_socket
is
not
None
:
self
.
output_socket
.
close
(
linger
=
0
)
if
self
.
input_socket
is
not
None
:
self
.
input_socket
.
close
(
linger
=
0
)
for
socket
in
(
self
.
output_socket
,
self
.
input_socket
,
self
.
first_req_send_socket
):
if
socket
is
not
None
:
socket
.
close
(
linger
=
0
)
if
self
.
shutdown_path
is
not
None
:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
...
...
@@ -350,6 +341,7 @@ class MPClient(EngineCoreClient):
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
):
self
.
vllm_config
=
vllm_config
# Serialization setup.
...
...
@@ -369,8 +361,8 @@ class MPClient(EngineCoreClient):
try
:
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
start_index
=
parallel_config
.
data_parallel_rank
local_start_index
=
parallel_config
.
data_parallel_rank_local
dp_size
=
parallel_config
.
data_parallel_size
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
...
...
@@ -382,42 +374,53 @@ class MPClient(EngineCoreClient):
CoreEngine
(
index
=
local_start_index
,
local
=
True
)
]
else
:
assert
start_index
==
0
assert
parallel_config
.
data_parallel_rank
==
0
local_start_index
=
0
self
.
core_engines
=
[
CoreEngine
(
index
=
i
,
local
=
(
i
<
local_engine_count
))
for
i
in
range
(
p
arallel_config
.
data_parallel
_size
)
for
i
in
range
(
d
p_size
)
]
input_address
,
output_address
=
self
.
_get_zmq_addresses
(
parallel_config
,
spmd_mode
)
local_only
=
spmd_mode
or
local_engine_count
==
dp_size
self
.
stats_update_address
:
Optional
[
str
]
=
None
if
client_addresses
is
not
None
:
input_address
=
client_addresses
[
"input_address"
]
output_address
=
client_addresses
[
"output_address"
]
self
.
stats_update_address
=
client_addresses
.
get
(
"stats_update_address"
)
else
:
host
=
parallel_config
.
data_parallel_master_ip
input_address
=
get_engine_client_zmq_addr
(
local_only
,
host
)
output_address
=
get_engine_client_zmq_addr
(
local_only
,
host
)
# Create input and output sockets.
self
.
input_socket
=
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_address
,
zmq
.
ROUTER
,
bind
=
True
)
self
.
resources
.
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_address
,
zmq
.
constants
.
PULL
)
# Start local engines.
if
local_engine_count
:
# In server mode, start_index and local_start_index will
# both be 0.
self
.
resources
.
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
input_address
=
input_address
,
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
start_index
=
start_index
,
local_start_index
=
local_start_index
)
self
.
ctx
,
output_address
,
zmq
.
PULL
)
if
client_addresses
is
None
:
self
.
_init_engines_direct
(
vllm_config
,
local_only
,
local_start_index
,
input_address
,
output_address
,
executor_class
,
log_stats
)
coordinator
=
self
.
resources
.
coordinator
if
coordinator
:
self
.
stats_update_address
=
(
coordinator
.
get_stats_publish_address
())
# Wait for ready messages from each engine on the input socket.
identities
=
set
(
e
.
identity
for
e
in
self
.
core_engines
)
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
while
identities
:
if
not
sync_input_socket
.
poll
(
timeout
=
600_000
):
raise
TimeoutError
(
"Timed out waiting for engines to send"
"initial message on input socket."
)
identity
,
_
=
sync_input_socket
.
recv_multipart
()
identities
.
remove
(
identity
)
self
.
core_engine
=
self
.
core_engines
[
0
]
# Wait for engine core process(es) to start.
self
.
_wait_for_engine_startup
(
output_address
,
parallel_config
)
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
# Request objects which may contain pytorch-allocated tensors
...
...
@@ -430,116 +433,67 @@ class MPClient(EngineCoreClient):
if
not
success
:
self
.
_finalizer
()
@
staticmethod
def
_get_zmq_addresses
(
parallel_config
:
ParallelConfig
,
spmd_mode
:
bool
)
->
tuple
[
str
,
str
]:
"""Returns (input_address, output_address)."""
dp_size
=
parallel_config
.
data_parallel_size
def
_init_engines_direct
(
self
,
vllm_config
:
VllmConfig
,
local_only
:
bool
,
local_start_index
:
int
,
input_address
:
str
,
output_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
"""Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
start_index
=
parallel_config
.
data_parallel_rank
host
=
parallel_config
.
data_parallel_master_ip
if
local_engine_count
==
dp_size
or
spmd_mode
:
input_address
=
get_open_zmq_ipc_path
()
output_address
=
get_open_zmq_ipc_path
()
else
:
host
=
parallel_config
.
data_parallel_master_ip
input_port
=
parallel_config
.
data_parallel_rpc_port
output_port
=
get_open_port
()
input_address
=
get_tcp_uri
(
host
,
input_port
)
output_address
=
get_tcp_uri
(
host
,
output_port
)
return
input_address
,
output_address
def
_wait_for_engine_startup
(
self
,
output_address
:
str
,
parallel_config
:
ParallelConfig
):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
# Wait for engine core process(es) to send ready messages.
local_count
=
parallel_config
.
data_parallel_size_local
remote_count
=
len
(
self
.
core_engines
)
-
local_count
# [local, remote] counts
conn_pending
,
start_pending
=
[
local_count
,
remote_count
],
[
0
,
0
]
poller
=
zmq
.
Poller
()
poller
.
register
(
sync_input_socket
,
zmq
.
POLLIN
)
proc_manager
=
self
.
resources
.
local_engine_manager
if
proc_manager
is
not
None
:
for
sentinel
in
proc_manager
.
sentinels
():
poller
.
register
(
sentinel
,
zmq
.
POLLIN
)
while
any
(
conn_pending
)
or
any
(
start_pending
):
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
if
any
(
conn_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect."
,
*
conn_pending
)
if
any
(
start_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to start."
,
*
start_pending
)
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
sync_input_socket
:
# One of the local core processes exited.
finished
=
proc_manager
.
finished_procs
(
)
if
proc_manager
else
{}
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above. "
f
"Failed core proc(s):
{
finished
}
"
)
# Receive HELLO and READY messages from the input socket.
eng_identity
,
ready_msg_bytes
=
sync_input_socket
.
recv_multipart
()
eng_index
=
int
.
from_bytes
(
eng_identity
,
byteorder
=
"little"
)
engine
=
next
(
(
e
for
e
in
self
.
core_engines
if
e
.
identity
==
eng_identity
),
None
)
if
engine
is
None
:
raise
RuntimeError
(
f
"Message from engine with unexpected data "
f
"parallel rank:
{
eng_index
}
"
)
msg
=
msgspec
.
msgpack
.
decode
(
ready_msg_bytes
)
status
,
local
=
msg
[
"status"
],
msg
[
"local"
]
if
local
!=
engine
.
local
:
raise
RuntimeError
(
f
"
{
status
}
message from "
f
"
{
'local'
if
local
else
'remote'
}
"
f
"engine
{
eng_index
}
, expected it to be "
f
"
{
'local'
if
engine
.
local
else
'remote'
}
"
)
if
status
==
"HELLO"
and
engine
.
state
==
CoreEngineState
.
NEW
:
# Send init message with DP config info.
init_message
=
self
.
encoder
.
encode
({
"output_socket_address"
:
output_address
,
"parallel_config"
:
{
"data_parallel_master_ip"
:
parallel_config
.
data_parallel_master_ip
,
"data_parallel_master_port"
:
parallel_config
.
data_parallel_master_port
,
"data_parallel_size"
:
parallel_config
.
data_parallel_size
,
},
})
sync_input_socket
.
send_multipart
((
eng_identity
,
*
init_message
),
copy
=
False
)
conn_pending
[
0
if
local
else
1
]
-=
1
start_pending
[
0
if
local
else
1
]
+=
1
engine
.
state
=
CoreEngineState
.
CONNECTED
elif
status
==
"READY"
and
(
engine
.
state
==
CoreEngineState
.
CONNECTED
):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
cache_config
=
self
.
vllm_config
.
cache_config
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
or
0
num_gpu_blocks
+=
msg
[
'num_gpu_blocks'
]
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
start_pending
[
0
if
local
else
1
]
-=
1
engine
.
state
=
CoreEngineState
.
READY
else
:
raise
RuntimeError
(
f
"Unexpected
{
status
}
message for "
f
"
{
'local'
if
local
else
'remote'
}
engine "
f
"
{
eng_index
}
in
{
engine
.
state
}
state."
)
if
len
(
self
.
core_engines
)
>
1
:
self
.
resources
.
coordinator
=
DPCoordinator
(
parallel_config
)
handshake_address
=
get_engine_client_zmq_addr
(
local_only
,
host
,
parallel_config
.
data_parallel_rpc_port
)
logger
.
debug
(
"%s from %s core engine process %s."
,
status
,
"local"
if
local
else
"remote"
,
eng_index
)
with
zmq_socket_ctx
(
handshake_address
,
zmq
.
ROUTER
,
bind
=
True
)
as
handshake_socket
:
# Start local engines.
if
local_engine_count
:
# In server mode, start_index and local_start_index will
# both be 0.
self
.
resources
.
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
handshake_address
=
handshake_address
,
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
start_index
=
start_index
,
local_start_index
=
local_start_index
)
# Wait for engine core process(es) to start.
self
.
_wait_for_engine_startup
(
handshake_socket
,
input_address
,
output_address
)
def
_wait_for_engine_startup
(
self
,
handshake_socket
:
zmq
.
Socket
,
input_address
:
str
,
output_address
:
str
):
addresses
=
EngineZmqAddresses
(
inputs
=
[
input_address
],
outputs
=
[
output_address
],
)
coordinator
=
self
.
resources
.
coordinator
if
coordinator
is
not
None
:
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
coordinator
.
get_engine_socket_addresses
())
wait_for_engine_startup
(
handshake_socket
,
addresses
,
self
.
core_engines
,
self
.
vllm_config
.
parallel_config
,
self
.
vllm_config
.
cache_config
,
self
.
resources
.
local_engine_manager
,
coordinator
.
proc
if
coordinator
else
None
,
)
def
shutdown
(
self
):
# Terminate background resources.
...
...
@@ -605,8 +559,8 @@ class SyncMPClient(MPClient):
try
:
shutdown_socket
.
bind
(
shutdown_path
)
poller
=
zmq
.
Poller
()
poller
.
register
(
shutdown_socket
)
poller
.
register
(
out_socket
)
poller
.
register
(
shutdown_socket
,
zmq
.
POLLIN
)
poller
.
register
(
out_socket
,
zmq
.
POLLIN
)
while
True
:
socks
=
poller
.
poll
()
if
not
socks
:
...
...
@@ -668,7 +622,7 @@ class SyncMPClient(MPClient):
future
:
Future
[
Any
]
=
Future
()
self
.
utility_results
[
call_id
]
=
future
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
(
call_id
,
method
,
args
))
(
0
,
call_id
,
method
,
args
))
return
future
.
result
()
...
...
@@ -730,15 +684,21 @@ class SyncMPClient(MPClient):
class
AsyncMPClient
(
MPClient
):
"""Asyncio-compatible client for multi-proc EngineCore."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
):
super
().
__init__
(
asyncio_mode
=
True
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
client_addresses
=
client_addresses
,
)
self
.
client_index
=
client_index
self
.
outputs_queue
=
asyncio
.
Queue
[
Union
[
EngineCoreOutputs
,
Exception
]]()
try
:
...
...
@@ -854,12 +814,13 @@ class AsyncMPClient(MPClient):
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
(
self
.
client_index
,
call_id
,
method
,
args
)))
await
self
.
_send_input_message
(
message
,
engine
,
args
)
self
.
_ensure_output_queue_task
()
return
await
future
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
request
.
client_index
=
self
.
client_index
await
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
)
self
.
_ensure_output_queue_task
()
...
...
@@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
):
self
.
current_wave
=
0
self
.
engines_running
=
False
# To route aborts to the correct engine.
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
client_addresses
,
client_index
)
assert
len
(
self
.
core_engines
)
>
1
# List of [waiting, running] pair per engine.
self
.
lb_engines
:
list
[
list
[
int
]]
=
[]
self
.
first_req_sock_addr
=
get_open_zmq_inproc_path
()
self
.
first_req_send_socket
=
self
.
resources
.
first_req_send_socket
=
(
make_zmq_socket
(
self
.
ctx
,
self
.
first_req_sock_addr
,
zmq
.
PAIR
,
bind
=
True
))
try
:
# If we are running in an asyncio event loop, start the stats task.
# Otherwise, it will be started lazily.
asyncio
.
get_running_loop
()
self
.
_ensure_stats_update_task
()
except
RuntimeError
:
pass
def
_ensure_stats_update_task
(
self
):
resources
=
self
.
resources
if
resources
.
stats_update_task
is
not
None
:
return
assert
self
.
stats_update_address
is
not
None
async
def
run_engine_stats_update_task
():
with
make_zmq_socket
(
self
.
ctx
,
self
.
stats_update_address
,
zmq
.
XSUB
)
as
socket
,
make_zmq_socket
(
self
.
ctx
,
self
.
first_req_sock_addr
,
zmq
.
PAIR
,
bind
=
False
)
as
first_req_rcv_socket
:
# Send subscription message.
await
socket
.
send
(
b
'
\x01
'
)
poller
=
zmq
.
asyncio
.
Poller
()
poller
.
register
(
socket
,
zmq
.
POLLIN
)
poller
.
register
(
first_req_rcv_socket
,
zmq
.
POLLIN
)
while
True
:
events
=
await
poller
.
poll
()
if
not
self
.
engines_running
and
len
(
events
)
==
2
or
(
events
[
0
][
0
]
==
first_req_rcv_socket
):
# Send a message to notify the coordinator that
# we're sending a request while the engines are
# paused, so that it can wake the others up
# (to run dummy EP loop).
self
.
engines_running
=
True
buf
=
first_req_rcv_socket
.
recv
(
flags
=
zmq
.
NOBLOCK
).
result
()
target_eng_index
=
int
.
from_bytes
(
buf
,
"little"
)
msg
=
msgspec
.
msgpack
.
encode
(
(
target_eng_index
,
self
.
current_wave
))
await
socket
.
send
(
msg
)
buf
=
None
while
True
:
# Drain all stats events (we only care about latest).
future
:
asyncio
.
Future
[
bytes
]
=
socket
.
recv
(
flags
=
zmq
.
NOBLOCK
)
if
isinstance
(
future
.
exception
(),
zmq
.
Again
):
break
buf
=
future
.
result
()
if
buf
is
None
:
continue
# Update local load-balancing state.
counts
,
wave
,
running
=
msgspec
.
msgpack
.
decode
(
buf
)
self
.
current_wave
=
wave
self
.
engines_running
=
running
self
.
lb_engines
=
counts
resources
.
stats_update_task
=
asyncio
.
create_task
(
run_engine_stats_update_task
())
def
get_core_engine_for_request
(
self
)
->
CoreEngine
:
if
not
self
.
lb_engines
:
return
self
.
core_engines
[
0
]
# TODO use P2C alg for larger DP sizes
num_engines
=
len
(
self
.
lb_engines
)
min_counts
=
[
sys
.
maxsize
,
sys
.
maxsize
]
eng_index
=
0
for
i
in
range
(
num_engines
):
# Start from client_index to help with balancing when engines
# are empty.
idx
=
(
self
.
client_index
+
i
)
%
num_engines
counts
=
self
.
lb_engines
[
idx
]
if
counts
<
min_counts
:
min_counts
=
counts
eng_index
=
idx
# Adjust local counts for better balancing between stats updates
# from the coordinator (which happen every 100ms).
if
min_counts
[
0
]:
min_counts
[
0
]
+=
1
else
:
min_counts
[
1
]
+=
1
return
self
.
core_engines
[
eng_index
]
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
# Only the result from the first engine is returned.
return
(
await
asyncio
.
gather
(
*
[
...
...
@@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient):
]))[
0
]
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
self
.
_ensure_stats_update_task
()
request
.
current_wave
=
self
.
current_wave
request
.
client_index
=
self
.
client_index
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
chosen_engine
.
num_reqs_in_flight
+=
1
to_await
=
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
,
chosen_engine
)
if
not
self
.
engines_running
:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self
.
engines_running
=
True
to_await
=
asyncio
.
gather
(
to_await
,
# type: ignore[assignment]
*
self
.
_start_wave_coros
(
exclude_index
=
chosen_engine
.
index
))
# Notify coordinator that we're sending a request
await
self
.
first_req_send_socket
.
send
(
chosen_engine
.
identity
)
await
to_await
self
.
_ensure_output_queue_task
()
def
get_core_engine_for_request
(
self
)
->
CoreEngine
:
return
min
(
self
.
core_engines
,
key
=
lambda
e
:
e
.
num_reqs_in_flight
)
@
staticmethod
async
def
process_engine_outputs
(
self
:
"DPAsyncMPClient"
,
outputs
:
EngineCoreOutputs
):
if
self
.
reqs_in_flight
:
for
req_id
in
outputs
.
finished_requests
or
():
if
engine
:
=
self
.
reqs_in_flight
.
pop
(
req_id
,
None
):
engine
.
num_reqs_in_flight
-=
1
if
outputs
.
wave_complete
is
not
None
:
# Current wave is complete, move to next wave number
# and mark engines as paused.
if
self
.
current_wave
<=
outputs
.
wave_complete
:
self
.
current_wave
=
outputs
.
wave_complete
+
1
self
.
engines_running
=
False
elif
outputs
.
start_wave
is
not
None
and
(
outputs
.
start_wave
>
self
.
current_wave
or
(
outputs
.
start_wave
==
self
.
current_wave
and
not
self
.
engines_running
)):
# Engine received request for a non-current wave so we must ensure
# that other engines progress to the next wave.
self
.
current_wave
=
outputs
.
start_wave
self
.
engines_running
=
True
await
asyncio
.
gather
(
*
self
.
_start_wave_coros
(
exclude_index
=
outputs
.
engine_index
))
def
_start_wave_coros
(
self
,
exclude_index
:
int
)
->
list
[
Awaitable
[
None
]]:
logger
.
debug
(
"Sending start DP wave %d."
,
self
.
current_wave
)
return
[
self
.
_send_input
(
EngineCoreRequestType
.
START_DP_WAVE
,
self
.
current_wave
,
engine
)
for
engine
in
self
.
core_engines
if
engine
.
index
!=
exclude_index
]
if
outputs
.
finished_requests
and
self
.
reqs_in_flight
:
for
req_id
in
outputs
.
finished_requests
:
self
.
reqs_in_flight
.
pop
(
req_id
,
None
)
async
def
abort_requests_async
(
self
,
request_ids
:
list
[
str
])
->
None
:
if
not
request_ids
:
...
...
vllm/v1/metrics/loggers.py
View file @
2dbe8c07
...
...
@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingLogging
,
SpecDecodingProm
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5.0
StatLoggerFactory
=
Callable
[[
VllmConfig
,
int
],
"StatLoggerBase"
]
...
...
@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
...
@
abstractmethod
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
...
...
...
@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
"""Log Stats to standard output."""
if
iteration_stats
:
self
.
_track_iteration_stats
(
iteration_stats
)
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
is
not
None
:
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
now
=
time
.
monotonic
()
...
...
@@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
logger
.
info
(
"vllm cache_config_info with initialization "
\
"after num_gpu_blocks is: %d"
,
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
logger
.
info
(
"Engine %03d: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d"
,
self
.
engine_index
,
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
class
PrometheusStatLogger
(
StatLoggerBase
):
...
...
@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls
=
SpecDecodingProm
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
self
.
_unregister_vllm_metrics
()
unregister_vllm_metrics
()
self
.
vllm_config
=
vllm_config
self
.
engine_index
=
engine_index
# Use this flag to hide metrics that were deprecated in
...
...
@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests in model execution batches."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
#
...
...
@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_gpu_prefix_cache_queries
=
self
.
_counter_cls
(
...
...
@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets
=
build_1_2_5_buckets
(
max_model_len
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self
.
histogram_iteration_tokens
=
\
self
.
_histogram_cls
(
name
=
"vllm:iteration_tokens_total"
,
...
...
@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
#
# LoRA metrics
#
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
self
.
gauge_lora_info
:
Optional
[
prometheus_client
.
Gauge
]
=
None
if
vllm_config
.
lora_config
is
not
None
:
self
.
labelname_max_lora
=
"max_lora"
...
...
@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
_gauge_cls
(
name
=
"vllm:lora_requests_info"
,
documentation
=
"Running stats on lora requests."
,
multiprocess_mode
=
"sum"
,
labelnames
=
[
self
.
labelname_max_lora
,
self
.
labelname_waiting_lora_adapters
,
self
.
labelname_running_lora_adapters
,
])
],
)
def
log_metrics_info
(
self
,
type
:
str
,
config_obj
:
SupportsMetricsInfo
):
metrics_info
=
config_obj
.
metrics_info
()
metrics_info
[
"engine"
]
=
self
.
engine_index
...
...
@@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge
=
self
.
_gauge_cls
(
name
=
name
,
documentation
=
documentation
,
labelnames
=
metrics_info
.
keys
()).
labels
(
**
metrics_info
)
multiprocess_mode
=
"mostrecent"
,
labelnames
=
metrics_info
.
keys
(),
).
labels
(
**
metrics_info
)
info_gauge
.
set
(
1
)
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
"""Log to prometheus."""
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
if
scheduler_stats
is
not
None
:
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
self
.
gauge_gpu_cache_usage
.
set
(
scheduler_stats
.
gpu_cache_usage
)
self
.
gauge_gpu_cache_usage
.
set
(
scheduler_stats
.
gpu_cache_usage
)
self
.
counter_gpu_prefix_cache_queries
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
queries
)
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
self
.
counter_gpu_prefix_cache_queries
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
queries
)
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_prom
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_prom
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
iteration_stats
is
None
:
return
...
...
@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_lora_info
.
labels
(
**
lora_info_labels
)
\
.
set_to_current_time
()
@
staticmethod
def
_unregister_vllm_metrics
():
# Unregister any existing vLLM collectors (for CI/CD
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
def
log_engine_initialized
(
self
):
self
.
log_metrics_info
(
"cache_config"
,
self
.
vllm_config
.
cache_config
)
...
...
vllm/v1/metrics/prometheus.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
os
import
tempfile
from
typing
import
Optional
from
prometheus_client
import
REGISTRY
,
CollectorRegistry
,
multiprocess
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir
:
Optional
[
tempfile
.
TemporaryDirectory
]
=
None
def
setup_multiprocess_prometheus
():
"""Set up prometheus multiprocessing directory if not already configured.
"""
global
_prometheus_multiproc_dir
if
"PROMETHEUS_MULTIPROC_DIR"
not
in
os
.
environ
:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
_prometheus_multiproc_dir
=
tempfile
.
TemporaryDirectory
()
os
.
environ
[
"PROMETHEUS_MULTIPROC_DIR"
]
=
_prometheus_multiproc_dir
.
name
logger
.
debug
(
"Created PROMETHEUS_MULTIPROC_DIR at %s"
,
_prometheus_multiproc_dir
.
name
)
else
:
logger
.
warning
(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup."
)
def
get_prometheus_registry
():
"""Get the appropriate prometheus registry based on multiprocessing
configuration.
Returns:
Registry: A prometheus registry
"""
if
os
.
getenv
(
"PROMETHEUS_MULTIPROC_DIR"
)
is
not
None
:
logger
.
debug
(
"Using multiprocess registry for prometheus metrics"
)
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
return
registry
return
REGISTRY
def
unregister_vllm_metrics
():
"""Unregister any existing vLLM collectors from the prometheus registry.
This is useful for testing and CI/CD where metrics may be registered
multiple times across test runs.
Also, in case of multiprocess, we need to unregister the metrics from the
global registry.
"""
registry
=
REGISTRY
# Unregister any existing vLLM collectors
for
collector
in
list
(
registry
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
registry
.
unregister
(
collector
)
def
shutdown_prometheus
():
"""Shutdown prometheus metrics."""
try
:
pid
=
os
.
getpid
()
multiprocess
.
mark_process_dead
(
pid
)
logger
.
debug
(
"Marked Prometheus metrics for process %d as dead"
,
pid
)
except
Exception
as
e
:
logger
.
error
(
"Error during metrics cleanup: %s"
,
str
(
e
))
vllm/v1/request.py
View file @
2dbe8c07
...
...
@@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
client_index
:
int
=
0
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
None
:
self
.
request_id
=
request_id
self
.
client_index
=
client_index
self
.
sampling_params
=
sampling_params
# Because of LoRA, the eos token id can be different for each request.
self
.
eos_token_id
=
eos_token_id
...
...
@@ -90,13 +91,13 @@ class Request:
return
cls
(
request_id
=
request
.
request_id
,
client_index
=
request
.
client_index
,
prompt_token_ids
=
request
.
prompt_token_ids
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
sampling_params
=
request
.
sampling_params
,
eos_token_id
=
request
.
eos_token_id
,
arrival_time
=
request
.
arrival_time
,
lora_request
=
request
.
lora_request
,
structured_output_request
=
StructuredOutputRequest
(
sampling_params
=
request
.
sampling_params
),
...
...
vllm/v1/utils.py
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
os
import
argparse
import
multiprocessing
import
time
import
weakref
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
multiprocessing
import
Process
,
connection
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
from
multiprocessing.process
import
BaseProcess
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
import
msgspec
import
torch
import
zmq
from
vllm.config
import
VllmConfig
from
vllm.config
import
CacheConfig
,
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
get_mp_context
,
kill_process_tree
from
vllm.utils
import
(
get_mp_context
,
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
kill_process_tree
)
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
logger
=
init_logger
(
__name__
)
T
=
TypeVar
(
"T"
)
STARTUP_POLL_PERIOD_MS
=
10000
class
ConstantList
(
Generic
[
T
],
Sequence
):
...
...
@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return
f
"ConstantList(
{
self
.
_x
}
)"
def
get_engine_client_zmq_addr
(
local_only
:
bool
,
host
:
str
,
port
:
int
=
0
)
->
str
:
return
get_open_zmq_ipc_path
()
if
local_only
else
(
get_tcp_uri
(
host
,
port
or
get_open_port
()))
class
APIServerProcessManager
:
"""Manages a group of API server processes.
Handles creation, monitoring, and termination of API server worker
processes. Also monitors extra processes to check if they are healthy.
"""
def
__init__
(
self
,
target_server_fn
:
Callable
,
listen_address
:
str
,
sock
:
Any
,
args
:
argparse
.
Namespace
,
num_servers
:
int
,
input_addresses
:
list
[
str
],
output_addresses
:
list
[
str
],
stats_update_address
:
Optional
[
str
]
=
None
,
):
"""Initialize and start API server worker processes.
Args:
target_server_fn: Function to call for each API server process
listen_address: Address to listen for client connections
sock: Socket for client connections
args: Command line arguments
num_servers: Number of API server processes to start
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
"""
self
.
listen_address
=
listen_address
self
.
sock
=
sock
self
.
args
=
args
# Start API servers
spawn_context
=
multiprocessing
.
get_context
(
"spawn"
)
self
.
processes
:
list
[
BaseProcess
]
=
[]
for
i
,
in_addr
,
out_addr
in
zip
(
range
(
num_servers
),
input_addresses
,
output_addresses
):
client_config
=
{
"input_address"
:
in_addr
,
"output_address"
:
out_addr
,
"client_index"
:
i
}
if
stats_update_address
is
not
None
:
client_config
[
"stats_update_address"
]
=
stats_update_address
proc
=
spawn_context
.
Process
(
target
=
target_server_fn
,
name
=
f
"ApiServer_
{
i
}
"
,
args
=
(
listen_address
,
sock
,
args
,
client_config
))
self
.
processes
.
append
(
proc
)
proc
.
start
()
logger
.
info
(
"Started %d API server processes"
,
len
(
self
.
processes
))
# Shutdown only the API server processes on garbage collection
# The extra processes are managed by their owners
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
def
close
(
self
)
->
None
:
self
.
_finalizer
()
class
CoreEngineProcManager
:
"""
Utility class to handle creation, readiness, and shutdown
...
...
@@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index
:
int
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
...
...
@@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs
=
{
"vllm_config"
:
vllm_config
,
"on_head_node"
:
on_head_node
,
"
input
_address"
:
input
_address
,
"
handshake
_address"
:
handshake
_address
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
}
self
.
processes
:
list
[
Process
]
=
[]
self
.
processes
:
list
[
Base
Process
]
=
[]
for
index
in
range
(
local_engine_count
):
local_index
=
local_start_index
+
index
global_index
=
start_index
+
index
...
...
@@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank"
:
local_index
,
}))
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
,
input_address
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
try
:
for
proc
in
self
.
processes
:
proc
.
start
()
...
...
@@ -164,9 +245,199 @@ class CoreEngineProcManager:
}
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
.
local
=
local
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
2
,
"little"
)
self
.
state
=
CoreEngineState
.
NEW
@
dataclass
class
EngineZmqAddresses
:
# ZMQ input socket addresses for each front-end client (requests)
inputs
:
list
[
str
]
# ZMQ output socket addresses for each front-end client (responses)
outputs
:
list
[
str
]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input
:
Optional
[
str
]
=
None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output
:
Optional
[
str
]
=
None
@
dataclass
class
EngineHandshakeMetadata
:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses
:
EngineZmqAddresses
parallel_config
:
dict
[
str
,
Union
[
int
,
str
]]
def
wait_for_engine_startup
(
handshake_socket
:
zmq
.
Socket
,
addresses
:
EngineZmqAddresses
,
core_engines
:
list
[
CoreEngine
],
parallel_config
:
ParallelConfig
,
cache_config
:
CacheConfig
,
proc_manager
:
Optional
[
CoreEngineProcManager
],
coord_process
:
Optional
[
Process
],
):
# Wait for engine core process(es) to send ready messages.
local_count
=
parallel_config
.
data_parallel_size_local
remote_count
=
len
(
core_engines
)
-
local_count
# [local, remote] counts
conn_pending
,
start_pending
=
[
local_count
,
remote_count
],
[
0
,
0
]
poller
=
zmq
.
Poller
()
poller
.
register
(
handshake_socket
,
zmq
.
POLLIN
)
if
proc_manager
is
not
None
:
for
sentinel
in
proc_manager
.
sentinels
():
poller
.
register
(
sentinel
,
zmq
.
POLLIN
)
if
coord_process
is
not
None
:
poller
.
register
(
coord_process
.
sentinel
,
zmq
.
POLLIN
)
while
any
(
conn_pending
)
or
any
(
start_pending
):
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
if
any
(
conn_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect."
,
*
conn_pending
)
if
any
(
start_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to start."
,
*
start_pending
)
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
handshake_socket
:
# One of the local core processes exited.
finished
=
proc_manager
.
finished_procs
()
if
proc_manager
else
{}
if
coord_process
is
not
None
and
coord_process
.
exitcode
is
not
None
:
finished
[
coord_process
.
name
]
=
coord_process
.
exitcode
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above. "
f
"Failed core proc(s):
{
finished
}
"
)
# Receive HELLO and READY messages from the input socket.
eng_identity
,
ready_msg_bytes
=
handshake_socket
.
recv_multipart
()
eng_index
=
int
.
from_bytes
(
eng_identity
,
"little"
)
engine
=
next
((
e
for
e
in
core_engines
if
e
.
identity
==
eng_identity
),
None
)
if
engine
is
None
:
raise
RuntimeError
(
f
"Message from engine with unexpected data "
f
"parallel rank:
{
eng_index
}
"
)
msg
=
msgspec
.
msgpack
.
decode
(
ready_msg_bytes
)
status
,
local
=
msg
[
"status"
],
msg
[
"local"
]
if
local
!=
engine
.
local
:
raise
RuntimeError
(
f
"
{
status
}
message from "
f
"
{
'local'
if
local
else
'remote'
}
"
f
"engine
{
eng_index
}
, expected it to be "
f
"
{
'local'
if
engine
.
local
else
'remote'
}
"
)
if
status
==
"HELLO"
and
engine
.
state
==
CoreEngineState
.
NEW
:
# Send init message with DP config info.
init_message
=
msgspec
.
msgpack
.
encode
(
EngineHandshakeMetadata
(
addresses
=
addresses
,
parallel_config
=
{
"data_parallel_master_ip"
:
parallel_config
.
data_parallel_master_ip
,
"data_parallel_master_port"
:
parallel_config
.
data_parallel_master_port
,
"data_parallel_size"
:
parallel_config
.
data_parallel_size
,
}))
handshake_socket
.
send_multipart
((
eng_identity
,
init_message
),
copy
=
False
)
conn_pending
[
0
if
local
else
1
]
-=
1
start_pending
[
0
if
local
else
1
]
+=
1
engine
.
state
=
CoreEngineState
.
CONNECTED
elif
status
==
"READY"
and
(
engine
.
state
==
CoreEngineState
.
CONNECTED
):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
or
0
num_gpu_blocks
+=
msg
[
"num_gpu_blocks"
]
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
start_pending
[
0
if
local
else
1
]
-=
1
engine
.
state
=
CoreEngineState
.
READY
else
:
raise
RuntimeError
(
f
"Unexpected
{
status
}
message for "
f
"
{
'local'
if
local
else
'remote'
}
engine "
f
"
{
eng_index
}
in
{
engine
.
state
}
state."
)
logger
.
debug
(
"%s from %s core engine process %s."
,
status
,
"local"
if
local
else
"remote"
,
eng_index
)
def
wait_for_completion_or_failure
(
api_server_manager
:
APIServerProcessManager
,
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
,
coordinator
:
Optional
[
"DPCoordinator"
]
=
None
)
->
None
:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
"""
try
:
logger
.
info
(
"Waiting for API servers to complete ..."
)
# Create a mapping of sentinels to their corresponding processes
# for efficient lookup
sentinel_to_proc
:
dict
[
Any
,
BaseProcess
]
=
{
proc
.
sentinel
:
proc
for
proc
in
api_server_manager
.
processes
}
if
coordinator
:
sentinel_to_proc
[
coordinator
.
proc
.
sentinel
]
=
coordinator
.
proc
if
local_engine_manager
:
for
proc
in
local_engine_manager
.
processes
:
sentinel_to_proc
[
proc
.
sentinel
]
=
proc
# Check if any process terminates
while
sentinel_to_proc
:
# Wait for any process to terminate
ready_sentinels
:
list
[
Any
]
=
connection
.
wait
(
sentinel_to_proc
)
# Process any terminated processes
for
sentinel
in
ready_sentinels
:
proc
=
sentinel_to_proc
.
pop
(
sentinel
)
# Check if process exited with error
if
proc
.
exitcode
!=
0
:
raise
RuntimeError
(
f
"Process
{
proc
.
name
}
(PID:
{
proc
.
pid
}
) "
f
"died with exit code
{
proc
.
exitcode
}
"
)
except
KeyboardInterrupt
:
logger
.
info
(
"Received KeyboardInterrupt, shutting down API servers..."
)
except
Exception
as
e
:
logger
.
exception
(
"Exception occurred while running API servers: %s"
,
str
(
e
))
raise
finally
:
logger
.
info
(
"Terminating remaining processes ..."
)
api_server_manager
.
close
()
if
coordinator
:
coordinator
.
close
()
if
local_engine_manager
:
local_engine_manager
.
close
()
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the obje
decoup
ct.
def
shutdown
(
procs
:
list
[
Process
]
,
input_address
:
str
):
# else the gc cannot collect the object.
def
shutdown
(
procs
:
list
[
Base
Process
]):
# Shutdown the process.
for
proc
in
procs
:
if
proc
.
is_alive
():
...
...
@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
kill_process_tree
(
pid
)
# Remove zmq ipc socket files.
if
input_address
.
startswith
(
"ipc://"
):
socket_file
=
input_address
[
len
(
"ipc://"
):]
if
os
and
os
.
path
.
exists
(
socket_file
):
os
.
remove
(
socket_file
)
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment