Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dbe8c07
"vllm/vscode:/vscode.git/clone" did not exist on "45f526d65237d9073a5f3be166b306580687f210"
Unverified
Commit
2dbe8c07
authored
May 30, 2025
by
Nick Hill
Committed by
GitHub
May 30, 2025
Browse files
[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
parent
84ec470f
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
818 additions
and
353 deletions
+818
-353
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+166
-87
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+244
-212
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+45
-34
vllm/v1/metrics/prometheus.py
vllm/v1/metrics/prometheus.py
+77
-0
vllm/v1/request.py
vllm/v1/request.py
+3
-2
vllm/v1/utils.py
vllm/v1/utils.py
+283
-18
No files found.
vllm/v1/engine/core.py
View file @
2dbe8c07
...
@@ -7,6 +7,7 @@ import threading
...
@@ -7,6 +7,7 @@ import threading
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
contextlib
import
ExitStack
from
inspect
import
isclass
,
signature
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
logging
import
DEBUG
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
...
@@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
...
@@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
make_zmq_socket
,
resolve_obj_by_qualname
,
zmq_socket_ctx
from
vllm.utils
import
make_zmq_socket
,
resolve_obj_by_qualname
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
unify_kv_cache_configs
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.interface
import
SchedulerInterface
...
@@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
...
@@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.utils
import
EngineHandshakeMetadata
,
EngineZmqAddresses
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -211,7 +214,7 @@ class EngineCore:
...
@@ -211,7 +214,7 @@ class EngineCore:
# Re-raise exception
# Re-raise exception
raise
err
raise
err
def
step
(
self
)
->
tuple
[
EngineCoreOutputs
,
bool
]:
def
step
(
self
)
->
tuple
[
dict
[
int
,
EngineCoreOutputs
]
,
bool
]:
"""Schedule, execute, and make output.
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
Returns tuple of outputs and a flag indicating whether the model
...
@@ -221,10 +224,7 @@ class EngineCore:
...
@@ -221,10 +224,7 @@ class EngineCore:
# Check for any requests remaining in the scheduler - unfinished,
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
# or finished and not yet removed from the batch.
if
not
self
.
scheduler
.
has_requests
():
if
not
self
.
scheduler
.
has_requests
():
return
EngineCoreOutputs
(
return
{},
False
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
(),
),
False
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
=
self
.
scheduler
.
schedule
()
model_output
=
self
.
execute_model
(
scheduler_output
)
model_output
=
self
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
...
@@ -234,7 +234,7 @@ class EngineCore:
...
@@ -234,7 +234,7 @@ class EngineCore:
scheduler_output
.
total_num_scheduled_tokens
>
0
)
scheduler_output
.
total_num_scheduled_tokens
>
0
)
def
step_with_batch_queue
(
def
step_with_batch_queue
(
self
)
->
tuple
[
Optional
[
EngineCoreOutputs
],
bool
]:
self
)
->
tuple
[
Optional
[
dict
[
int
,
EngineCoreOutputs
]
]
,
bool
]:
"""Schedule and execute batches with the batch queue.
"""Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned.
Note that if nothing to output in this step, None is returned.
...
@@ -276,8 +276,8 @@ class EngineCore:
...
@@ -276,8 +276,8 @@ class EngineCore:
# Blocking until the first result is available.
# Blocking until the first result is available.
model_output
=
future
.
result
()
model_output
=
future
.
result
()
self
.
batch_queue
.
task_done
()
self
.
batch_queue
.
task_done
()
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
engine_core_outputs
=
(
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
scheduler_output
,
model_output
)
)
return
engine_core_outputs
,
scheduled_batch
return
engine_core_outputs
,
scheduled_batch
...
@@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
...
@@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
engine_index
:
int
=
0
,
engine_index
:
int
=
0
,
...
@@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):
...
@@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):
# Create input socket.
# Create input socket.
input_ctx
=
zmq
.
Context
()
input_ctx
=
zmq
.
Context
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
input_socket
=
make_zmq_socket
(
input_ctx
,
with
make_zmq_socket
(
input_ctx
,
input_address
,
handshake_address
,
zmq
.
DEALER
,
zmq
.
DEALER
,
identity
=
identity
,
identity
=
identity
,
bind
=
False
)
linger
=
5000
,
try
:
bind
=
False
)
as
handshake_socket
:
# Register engine with front-end.
# Register engine with front-end.
output_address
=
self
.
startup_handshake
(
addresses
=
self
.
startup_handshake
(
handshake_socket
,
on_head_node
,
input_socket
,
on_head_node
,
vllm_config
.
parallel_config
)
vllm_config
.
parallel_config
)
self
.
client_count
=
len
(
addresses
.
outputs
)
# Update config which may have changed from the handshake.
# Update config which may have changed from the handshake.
vllm_config
.
__post_init__
()
vllm_config
.
__post_init__
()
# Set up data parallel environment.
# Set up data parallel environment.
self
.
has_coordinator
=
addresses
.
coordinator_output
is
not
None
self
.
_init_data_parallel
(
vllm_config
)
self
.
_init_data_parallel
(
vllm_config
)
# Initialize engine core and model.
# Initialize engine core and model.
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
executor_fail_callback
)
executor_fail_callback
)
self
.
engine_index
=
engine_index
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
step_with_batch_queue
)
self
.
engines_running
=
False
self
.
engines_running
=
False
self
.
last_counts
=
(
0
,
0
)
# Send ready message.
# Send ready message.
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
num_gpu_blocks
=
vllm_config
.
cache_config
.
num_gpu_blocks
input
_socket
.
send
(
handshake
_socket
.
send
(
msgspec
.
msgpack
.
encode
({
msgspec
.
msgpack
.
encode
({
"status"
:
"READY"
,
"status"
:
"READY"
,
"local"
:
on_head_node
,
"local"
:
on_head_node
,
"num_gpu_blocks"
:
num_gpu_blocks
,
"num_gpu_blocks"
:
num_gpu_blocks
,
}))
}))
# Background Threads and Queues for IO. These enable us to
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# and to overlap some serialization/deserialization with the
# model forward pass.
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
input_queue
self
.
input_queue
=
input_queue
self
.
output_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
bytes
]]()
self
.
output_queue
=
queue
.
Queue
[
Union
[
tuple
[
int
,
EngineCoreOutputs
],
threading
.
Thread
(
target
=
self
.
process_input_socket
,
bytes
]]()
args
=
(
input_socket
,
),
threading
.
Thread
(
target
=
self
.
process_input_sockets
,
daemon
=
True
).
start
()
args
=
(
addresses
.
inputs
,
addresses
.
coordinator_input
,
input_socket
=
None
identity
),
self
.
output_thread
=
threading
.
Thread
(
daemon
=
True
).
start
()
target
=
self
.
process_output_socket
,
self
.
output_thread
=
threading
.
Thread
(
args
=
(
output_address
,
engine_index
),
target
=
self
.
process_output_sockets
,
daemon
=
True
)
args
=
(
addresses
.
outputs
,
addresses
.
coordinator_output
,
self
.
output_thread
.
start
()
engine_index
),
finally
:
daemon
=
True
)
if
input_socket
is
not
None
:
self
.
output_thread
.
start
()
input_socket
.
close
(
linger
=
0
)
@
staticmethod
@
staticmethod
def
startup_handshake
(
input_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
def
startup_handshake
(
parallel_config
:
ParallelConfig
)
->
str
:
handshake_socket
:
zmq
.
Socket
,
on_head_node
:
bool
,
parallel_config
:
ParallelConfig
)
->
EngineZmqAddresses
:
# Send registration message.
# Send registration message.
input
_socket
.
send
(
handshake
_socket
.
send
(
msgspec
.
msgpack
.
encode
({
msgspec
.
msgpack
.
encode
({
"status"
:
"HELLO"
,
"status"
:
"HELLO"
,
"local"
:
on_head_node
,
"local"
:
on_head_node
,
...
@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):
...
@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):
# Receive initialization message.
# Receive initialization message.
logger
.
info
(
"Waiting for init message from front-end."
)
logger
.
info
(
"Waiting for init message from front-end."
)
if
not
input
_socket
.
poll
(
timeout
=
HANDSHAKE_TIMEOUT_MINS
*
60
*
1
000
):
if
not
handshake
_socket
.
poll
(
timeout
=
HANDSHAKE_TIMEOUT_MINS
*
60
_
000
):
raise
RuntimeError
(
"Did not receive response from front-end "
raise
RuntimeError
(
"Did not receive response from front-end "
f
"process within
{
HANDSHAKE_TIMEOUT_MINS
}
"
f
"process within
{
HANDSHAKE_TIMEOUT_MINS
}
"
f
"minutes"
)
f
"minutes"
)
init_bytes
=
input_socket
.
recv
()
init_bytes
=
handshake_socket
.
recv
()
init_message
=
msgspec
.
msgpack
.
decode
(
init_bytes
)
init_message
:
EngineHandshakeMetadata
=
msgspec
.
msgpack
.
decode
(
init_bytes
,
type
=
EngineHandshakeMetadata
)
logger
.
debug
(
"Received init message: %s"
,
init_message
)
logger
.
debug
(
"Received init message: %s"
,
init_message
)
output_socket_address
=
init_message
[
"output_socket_address"
]
received_parallel_config
=
init_message
.
parallel_config
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config
=
init_message
[
"parallel_config"
]
for
key
,
value
in
received_parallel_config
.
items
():
for
key
,
value
in
received_parallel_config
.
items
():
setattr
(
parallel_config
,
key
,
value
)
setattr
(
parallel_config
,
key
,
value
)
return
output_socket_
address
return
init_message
.
address
es
@
staticmethod
@
staticmethod
def
run_engine_core
(
*
args
,
def
run_engine_core
(
*
args
,
...
@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
...
@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed."""
"""Exits when an engine step needs to be performed."""
waited
=
False
waited
=
False
while
not
self
.
engines_running
and
not
(
self
.
scheduler
.
has_requests
()
)
:
while
not
self
.
engines_running
and
not
self
.
scheduler
.
has_requests
():
if
logger
.
isEnabledFor
(
DEBUG
)
and
self
.
input_queue
.
empty
():
if
logger
.
isEnabledFor
(
DEBUG
)
and
self
.
input_queue
.
empty
():
logger
.
debug
(
"EngineCore waiting for work."
)
logger
.
debug
(
"EngineCore waiting for work."
)
waited
=
True
waited
=
True
...
@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
...
@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
# Step the engine core.
# Step the engine core.
outputs
,
model_executed
=
self
.
step_fn
()
outputs
,
model_executed
=
self
.
step_fn
()
# Put EngineCoreOutputs into the output queue.
# Put EngineCoreOutputs into the output queue.
i
f
output
s
i
s
not
None
:
f
or
output
i
n
(
outputs
.
items
()
if
outputs
else
())
:
self
.
output_queue
.
put_nowait
(
output
s
)
self
.
output_queue
.
put_nowait
(
output
)
return
model_executed
return
model_executed
...
@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
...
@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
self
.
abort_requests
(
request
)
self
.
abort_requests
(
request
)
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
call_id
,
method_name
,
args
=
request
client_idx
,
call_id
,
method_name
,
args
=
request
output
=
UtilityOutput
(
call_id
)
output
=
UtilityOutput
(
call_id
)
try
:
try
:
method
=
getattr
(
self
,
method_name
)
method
=
getattr
(
self
,
method_name
)
...
@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
...
@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
output
.
failure_message
=
(
f
"Call to
{
method_name
}
method"
output
.
failure_message
=
(
f
"Call to
{
method_name
}
method"
f
" failed:
{
str
(
e
)
}
"
)
f
" failed:
{
str
(
e
)
}
"
)
self
.
output_queue
.
put_nowait
(
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
utility_output
=
output
))
(
client_idx
,
EngineCoreOutputs
(
utility_output
=
output
))
)
elif
request_type
==
EngineCoreRequestType
.
EXECUTOR_FAILED
:
elif
request_type
==
EngineCoreRequestType
.
EXECUTOR_FAILED
:
raise
RuntimeError
(
"Executor failed."
)
raise
RuntimeError
(
"Executor failed."
)
else
:
else
:
...
@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
...
@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
logger
.
fatal
(
"vLLM shutdown signal from EngineCore failed "
logger
.
fatal
(
"vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue."
)
"to send. Please report this issue."
)
def
process_input_socket
(
self
,
input_socket
:
zmq
.
Socket
):
def
process_input_sockets
(
self
,
input_addresses
:
list
[
str
],
coord_input_address
:
Optional
[
str
],
identity
:
bytes
):
"""Input socket IO thread."""
"""Input socket IO thread."""
# Msgpack serialization decoding.
# Msgpack serialization decoding.
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
generic_decoder
=
MsgpackDecoder
()
generic_decoder
=
MsgpackDecoder
()
while
True
:
with
ExitStack
()
as
stack
,
zmq
.
Context
()
as
ctx
:
# (RequestType, RequestData)
input_sockets
=
[
type_frame
,
*
data_frames
=
input_socket
.
recv_multipart
(
copy
=
False
)
stack
.
enter_context
(
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
make_zmq_socket
(
ctx
,
input_address
,
# Deserialize the request data.
zmq
.
DEALER
,
decoder
=
add_request_decoder
if
(
identity
=
identity
,
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
bind
=
False
))
request
=
decoder
.
decode
(
data_frames
)
for
input_address
in
input_addresses
]
# Push to input queue for core busy loop.
if
coord_input_address
is
None
:
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
coord_socket
=
None
else
:
coord_socket
=
stack
.
enter_context
(
make_zmq_socket
(
ctx
,
coord_input_address
,
zmq
.
XSUB
,
identity
=
identity
,
bind
=
False
))
# Send subscription message to coordinator.
coord_socket
.
send
(
b
'
\x01
'
)
# Register sockets with poller.
poller
=
zmq
.
Poller
()
for
input_socket
in
input_sockets
:
# Send initial message to each input socket - this is required
# before the front-end ROUTER socket can send input messages
# back to us.
input_socket
.
send
(
b
''
)
poller
.
register
(
input_socket
,
zmq
.
POLLIN
)
if
coord_socket
is
not
None
:
poller
.
register
(
coord_socket
,
zmq
.
POLLIN
)
def
process_output_socket
(
self
,
output_path
:
str
,
engine_index
:
int
):
while
True
:
for
input_socket
,
_
in
poller
.
poll
():
# (RequestType, RequestData)
type_frame
,
*
data_frames
=
input_socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frames
)
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
def
process_output_sockets
(
self
,
output_paths
:
list
[
str
],
coord_output_path
:
Optional
[
str
],
engine_index
:
int
):
"""Output socket IO thread."""
"""Output socket IO thread."""
# Msgpack serialization encoding.
# Msgpack serialization encoding.
...
@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):
...
@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):
# We must set linger to ensure the ENGINE_CORE_DEAD
# We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket.
# message is sent prior to closing the socket.
with
zmq_socket_ctx
(
output_path
,
zmq
.
constants
.
PUSH
,
with
ExitStack
()
as
stack
,
zmq
.
Context
()
as
ctx
:
linger
=
4000
)
as
socket
:
sockets
=
[
stack
.
enter_context
(
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
PUSH
,
linger
=
4000
))
for
output_path
in
output_paths
]
coord_socket
=
stack
.
enter_context
(
make_zmq_socket
(
ctx
,
coord_output_path
,
zmq
.
PUSH
,
bind
=
False
,
linger
=
4000
))
if
coord_output_path
is
not
None
else
None
max_reuse_bufs
=
len
(
sockets
)
+
1
while
True
:
while
True
:
outputs
=
self
.
output_queue
.
get
()
output
=
self
.
output_queue
.
get
()
if
outputs
==
EngineCoreProc
.
ENGINE_CORE_DEAD
:
if
output
==
EngineCoreProc
.
ENGINE_CORE_DEAD
:
socket
.
send
(
outputs
,
copy
=
False
)
for
socket
in
sockets
:
socket
.
send
(
output
)
break
break
assert
not
isinstance
(
outputs
,
bytes
)
assert
not
isinstance
(
output
,
bytes
)
client_index
,
outputs
=
output
outputs
.
engine_index
=
engine_index
outputs
.
engine_index
=
engine_index
if
client_index
==
-
1
:
# Don't reuse buffer for coordinator message
# which will be very small.
assert
coord_socket
is
not
None
coord_socket
.
send_multipart
(
encoder
.
encode
(
outputs
))
continue
# Reclaim buffers that zmq is finished with.
# Reclaim buffers that zmq is finished with.
while
pending
and
pending
[
-
1
][
0
].
done
:
while
pending
and
pending
[
-
1
][
0
].
done
:
reuse_buffers
.
append
(
pending
.
pop
()[
2
])
reuse_buffers
.
append
(
pending
.
pop
()[
2
])
buffer
=
reuse_buffers
.
pop
()
if
reuse_buffers
else
bytearray
()
buffer
=
reuse_buffers
.
pop
()
if
reuse_buffers
else
bytearray
()
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
tracker
=
socket
.
send_multipart
(
buffers
,
tracker
=
socket
s
[
client_index
]
.
send_multipart
(
buffers
,
copy
=
False
,
copy
=
False
,
track
=
True
)
track
=
True
)
if
not
tracker
.
done
:
if
not
tracker
.
done
:
ref
=
outputs
if
len
(
buffers
)
>
1
else
None
ref
=
outputs
if
len
(
buffers
)
>
1
else
None
pending
.
appendleft
((
tracker
,
ref
,
buffer
))
pending
.
appendleft
((
tracker
,
ref
,
buffer
))
elif
len
(
reuse_buffers
)
<
2
:
elif
len
(
reuse_buffers
)
<
max_reuse_bufs
:
#
Keep at most 2
buffers to reuse.
#
Limit the number of
buffers to reuse.
reuse_buffers
.
append
(
buffer
)
reuse_buffers
.
append
(
buffer
)
...
@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
):
):
...
@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
# Counts forward-passes of the model so that we can synchronize
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
# finished with DP peers every N steps.
self
.
counter
=
0
self
.
counter
=
0
self
.
current_wave
=
0
# Initialize the engine.
# Initialize the engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
super
().
__init__
(
vllm_config
,
on_head_node
,
input
_address
,
super
().
__init__
(
vllm_config
,
on_head_node
,
handshake
_address
,
executor_class
,
log_stats
,
dp_rank
)
executor_class
,
log_stats
,
dp_rank
)
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
...
@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):
self
.
dp_rank
=
dp_rank
self
.
dp_rank
=
dp_rank
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
self
.
current_wave
=
0
def
shutdown
(
self
):
def
shutdown
(
self
):
super
().
shutdown
()
super
().
shutdown
()
...
@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
stateless_destroy_torch_distributed_process_group
(
dp_group
)
def
add_request
(
self
,
request
:
EngineCoreRequest
):
def
add_request
(
self
,
request
:
EngineCoreRequest
):
if
request
.
current_wave
!=
self
.
current_wave
:
if
self
.
has_coordinator
and
request
.
current_wave
!=
self
.
current_wave
:
if
request
.
current_wave
>
self
.
current_wave
:
if
request
.
current_wave
>
self
.
current_wave
:
self
.
current_wave
=
request
.
current_wave
self
.
current_wave
=
request
.
current_wave
elif
not
self
.
engines_running
:
elif
not
self
.
engines_running
:
# Request received for an already-completed wave, notify
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
# front-end that we need to start the next one.
self
.
output_queue
.
put_nowait
(
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
start_wave
=
self
.
current_wave
))
(
-
1
,
EngineCoreOutputs
(
start_wave
=
self
.
current_wave
))
)
super
().
add_request
(
request
)
super
().
add_request
(
request
)
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
request
:
Any
)
->
None
:
if
request_type
==
EngineCoreRequestType
.
START_DP_WAVE
:
if
request_type
==
EngineCoreRequestType
.
START_DP_WAVE
:
new_wave
:
int
=
request
new_wave
,
exclude_eng_index
=
request
if
new_wave
>=
self
.
current_wave
:
if
exclude_eng_index
!=
self
.
engine_index
and
(
new_wave
>=
self
.
current_wave
):
self
.
current_wave
=
new_wave
self
.
current_wave
=
new_wave
if
not
self
.
engines_running
:
if
not
self
.
engines_running
:
logger
.
debug
(
"EngineCore starting idle loop for wave %d."
,
logger
.
debug
(
"EngineCore starting idle loop for wave %d."
,
...
@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
else
:
else
:
super
().
_handle_client_request
(
request_type
,
request
)
super
().
_handle_client_request
(
request_type
,
request
)
def
_maybe_publish_request_counts
(
self
):
if
not
self
.
has_coordinator
:
return
# Publish our request counts (if they've changed).
counts
=
self
.
scheduler
.
get_request_counts
()
if
counts
!=
self
.
last_counts
:
self
.
last_counts
=
counts
stats
=
SchedulerStats
(
*
counts
)
self
.
output_queue
.
put_nowait
(
(
-
1
,
EngineCoreOutputs
(
scheduler_stats
=
stats
)))
def
run_busy_loop
(
self
):
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore for data parallel case."""
"""Core busy loop of the EngineCore for data parallel case."""
...
@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):
# 2) Step the engine core.
# 2) Step the engine core.
executed
=
self
.
_process_engine_step
()
executed
=
self
.
_process_engine_step
()
self
.
_maybe_publish_request_counts
()
local_unfinished_reqs
=
self
.
scheduler
.
has_unfinished_requests
()
local_unfinished_reqs
=
self
.
scheduler
.
has_unfinished_requests
()
if
not
executed
:
if
not
executed
:
if
not
local_unfinished_reqs
and
not
self
.
engines_running
:
if
not
local_unfinished_reqs
and
not
self
.
engines_running
:
...
@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
self
.
current_wave
)
self
.
current_wave
)
self
.
output_queue
.
put_nowait
(
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
))
(
-
1
,
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
)))
self
.
current_wave
+=
1
self
.
current_wave
+=
1
def
_has_global_unfinished_reqs
(
self
,
local_unfinished
:
bool
)
->
bool
:
def
_has_global_unfinished_reqs
(
self
,
local_unfinished
:
bool
)
->
bool
:
...
...
vllm/v1/engine/core_client.py
View file @
2dbe8c07
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
asyncio
import
asyncio
import
contextlib
import
contextlib
import
queue
import
queue
import
sys
import
uuid
import
uuid
import
weakref
import
weakref
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
...
@@ -9,26 +10,28 @@ from collections import deque
...
@@ -9,26 +10,28 @@ from collections import deque
from
collections.abc
import
Awaitable
,
Sequence
from
collections.abc
import
Awaitable
,
Sequence
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
msgspec
.msgpack
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
(
get_open_port
,
get_open_zmq_inproc_path
,
from
vllm.utils
import
(
get_open_zmq_inproc_path
,
make_zmq_socket
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
make_
zmq_socket
)
zmq_socket
_ctx
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
CoreEngineProcManager
from
vllm.v1.utils
import
(
CoreEngine
,
CoreEngineProcManager
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_engine_startup
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
...
@@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
STARTUP_POLL_PERIOD_MS
=
10000
class
EngineCoreClient
(
ABC
):
class
EngineCoreClient
(
ABC
):
"""
"""
...
@@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient):
...
@@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient):
def
get_output
(
self
)
->
EngineCoreOutputs
:
def
get_output
(
self
)
->
EngineCoreOutputs
:
outputs
,
_
=
self
.
engine_core
.
step
()
outputs
,
_
=
self
.
engine_core
.
step
()
return
outputs
return
outputs
.
get
(
0
)
or
EngineCoreOutputs
()
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
self
.
engine_core
.
add_request
(
request
)
self
.
engine_core
.
add_request
(
request
)
...
@@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient):
...
@@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient):
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
.
local
=
local
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
self
.
state
=
CoreEngineState
.
NEW
self
.
num_reqs_in_flight
=
0
@
dataclass
@
dataclass
class
BackgroundResources
:
class
BackgroundResources
:
"""Used as a finalizer for clean shutdown, avoiding
"""Used as a finalizer for clean shutdown, avoiding
...
@@ -291,9 +274,12 @@ class BackgroundResources:
...
@@ -291,9 +274,12 @@ class BackgroundResources:
ctx
:
Union
[
zmq
.
Context
]
ctx
:
Union
[
zmq
.
Context
]
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
coordinator
:
Optional
[
DPCoordinator
]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
first_req_send_socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
output_queue_task
:
Optional
[
asyncio
.
Task
]
=
None
output_queue_task
:
Optional
[
asyncio
.
Task
]
=
None
stats_update_task
:
Optional
[
asyncio
.
Task
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
# Set if any of the engines are dead. Here so that the output
# Set if any of the engines are dead. Here so that the output
...
@@ -306,16 +292,21 @@ class BackgroundResources:
...
@@ -306,16 +292,21 @@ class BackgroundResources:
self
.
engine_dead
=
True
self
.
engine_dead
=
True
if
self
.
local_engine_manager
is
not
None
:
if
self
.
local_engine_manager
is
not
None
:
self
.
local_engine_manager
.
close
()
self
.
local_engine_manager
.
close
()
if
self
.
coordinator
is
not
None
:
self
.
coordinator
.
close
()
if
self
.
output_queue_task
is
not
None
:
if
self
.
output_queue_task
is
not
None
:
self
.
output_queue_task
.
cancel
()
self
.
output_queue_task
.
cancel
()
if
self
.
stats_update_task
is
not
None
:
self
.
stats_update_task
.
cancel
()
# ZMQ context termination can hang if the sockets
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
# aren't explicitly closed first.
if
self
.
output_socket
is
not
None
:
for
socket
in
(
self
.
output_socket
,
self
.
input_socket
,
self
.
output_socket
.
close
(
linger
=
0
)
self
.
first_req_send_socket
):
if
self
.
input_socket
is
not
None
:
if
socket
is
not
None
:
self
.
input_socket
.
close
(
linger
=
0
)
socket
.
close
(
linger
=
0
)
if
self
.
shutdown_path
is
not
None
:
if
self
.
shutdown_path
is
not
None
:
# We must ensure that the sync output socket is
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
# closed cleanly in its own thread.
...
@@ -350,6 +341,7 @@ class MPClient(EngineCoreClient):
...
@@ -350,6 +341,7 @@ class MPClient(EngineCoreClient):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
):
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
# Serialization setup.
# Serialization setup.
...
@@ -369,8 +361,8 @@ class MPClient(EngineCoreClient):
...
@@ -369,8 +361,8 @@ class MPClient(EngineCoreClient):
try
:
try
:
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
local_engine_count
=
parallel_config
.
data_parallel_size_local
start_index
=
parallel_config
.
data_parallel_rank
local_start_index
=
parallel_config
.
data_parallel_rank_local
local_start_index
=
parallel_config
.
data_parallel_rank_local
dp_size
=
parallel_config
.
data_parallel_size
# SPMD mode is where there is an LLM instance per DP rank and
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
# one core engine per LLM, see
...
@@ -382,42 +374,53 @@ class MPClient(EngineCoreClient):
...
@@ -382,42 +374,53 @@ class MPClient(EngineCoreClient):
CoreEngine
(
index
=
local_start_index
,
local
=
True
)
CoreEngine
(
index
=
local_start_index
,
local
=
True
)
]
]
else
:
else
:
assert
start_index
==
0
assert
parallel_config
.
data_parallel_rank
==
0
local_start_index
=
0
local_start_index
=
0
self
.
core_engines
=
[
self
.
core_engines
=
[
CoreEngine
(
index
=
i
,
local
=
(
i
<
local_engine_count
))
CoreEngine
(
index
=
i
,
local
=
(
i
<
local_engine_count
))
for
i
in
range
(
p
arallel_config
.
data_parallel
_size
)
for
i
in
range
(
d
p_size
)
]
]
input_address
,
output_address
=
self
.
_get_zmq_addresses
(
local_only
=
spmd_mode
or
local_engine_count
==
dp_size
parallel_config
,
spmd_mode
)
self
.
stats_update_address
:
Optional
[
str
]
=
None
if
client_addresses
is
not
None
:
input_address
=
client_addresses
[
"input_address"
]
output_address
=
client_addresses
[
"output_address"
]
self
.
stats_update_address
=
client_addresses
.
get
(
"stats_update_address"
)
else
:
host
=
parallel_config
.
data_parallel_master_ip
input_address
=
get_engine_client_zmq_addr
(
local_only
,
host
)
output_address
=
get_engine_client_zmq_addr
(
local_only
,
host
)
# Create input and output sockets.
# Create input and output sockets.
self
.
input_socket
=
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
input_socket
=
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_address
,
zmq
.
ROUTER
,
bind
=
True
)
self
.
ctx
,
input_address
,
zmq
.
ROUTER
,
bind
=
True
)
self
.
resources
.
output_socket
=
make_zmq_socket
(
self
.
resources
.
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_address
,
zmq
.
constants
.
PULL
)
self
.
ctx
,
output_address
,
zmq
.
PULL
)
# Start local engines.
if
local_engine_count
:
if
client_addresses
is
None
:
# In server mode, start_index and local_start_index will
self
.
_init_engines_direct
(
vllm_config
,
local_only
,
# both be 0.
local_start_index
,
input_address
,
self
.
resources
.
local_engine_manager
=
CoreEngineProcManager
(
output_address
,
executor_class
,
EngineCoreProc
.
run_engine_core
,
log_stats
)
vllm_config
=
vllm_config
,
coordinator
=
self
.
resources
.
coordinator
executor_class
=
executor_class
,
if
coordinator
:
log_stats
=
log_stats
,
self
.
stats_update_address
=
(
input_address
=
input_address
,
coordinator
.
get_stats_publish_address
())
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
# Wait for ready messages from each engine on the input socket.
start_index
=
start_index
,
identities
=
set
(
e
.
identity
for
e
in
self
.
core_engines
)
local_start_index
=
local_start_index
)
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
while
identities
:
if
not
sync_input_socket
.
poll
(
timeout
=
600_000
):
raise
TimeoutError
(
"Timed out waiting for engines to send"
"initial message on input socket."
)
identity
,
_
=
sync_input_socket
.
recv_multipart
()
identities
.
remove
(
identity
)
self
.
core_engine
=
self
.
core_engines
[
0
]
self
.
core_engine
=
self
.
core_engines
[
0
]
# Wait for engine core process(es) to start.
self
.
_wait_for_engine_startup
(
output_address
,
parallel_config
)
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
# Request objects which may contain pytorch-allocated tensors
# Request objects which may contain pytorch-allocated tensors
...
@@ -430,116 +433,67 @@ class MPClient(EngineCoreClient):
...
@@ -430,116 +433,67 @@ class MPClient(EngineCoreClient):
if
not
success
:
if
not
success
:
self
.
_finalizer
()
self
.
_finalizer
()
@
staticmethod
def
_init_engines_direct
(
self
,
vllm_config
:
VllmConfig
,
local_only
:
bool
,
def
_get_zmq_addresses
(
parallel_config
:
ParallelConfig
,
local_start_index
:
int
,
input_address
:
str
,
spmd_mode
:
bool
)
->
tuple
[
str
,
str
]:
output_address
:
str
,
"""Returns (input_address, output_address)."""
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
dp_size
=
parallel_config
.
data_parallel_size
"""Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
local_engine_count
=
parallel_config
.
data_parallel_size_local
start_index
=
parallel_config
.
data_parallel_rank
host
=
parallel_config
.
data_parallel_master_ip
if
local_engine_count
==
dp_size
or
spmd_mode
:
if
len
(
self
.
core_engines
)
>
1
:
input_address
=
get_open_zmq_ipc_path
()
self
.
resources
.
coordinator
=
DPCoordinator
(
parallel_config
)
output_address
=
get_open_zmq_ipc_path
()
else
:
handshake_address
=
get_engine_client_zmq_addr
(
host
=
parallel_config
.
data_parallel_master_ip
local_only
,
host
,
parallel_config
.
data_parallel_rpc_port
)
input_port
=
parallel_config
.
data_parallel_rpc_port
output_port
=
get_open_port
()
input_address
=
get_tcp_uri
(
host
,
input_port
)
output_address
=
get_tcp_uri
(
host
,
output_port
)
return
input_address
,
output_address
def
_wait_for_engine_startup
(
self
,
output_address
:
str
,
parallel_config
:
ParallelConfig
):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
# Wait for engine core process(es) to send ready messages.
local_count
=
parallel_config
.
data_parallel_size_local
remote_count
=
len
(
self
.
core_engines
)
-
local_count
# [local, remote] counts
conn_pending
,
start_pending
=
[
local_count
,
remote_count
],
[
0
,
0
]
poller
=
zmq
.
Poller
()
poller
.
register
(
sync_input_socket
,
zmq
.
POLLIN
)
proc_manager
=
self
.
resources
.
local_engine_manager
if
proc_manager
is
not
None
:
for
sentinel
in
proc_manager
.
sentinels
():
poller
.
register
(
sentinel
,
zmq
.
POLLIN
)
while
any
(
conn_pending
)
or
any
(
start_pending
):
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
if
any
(
conn_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect."
,
*
conn_pending
)
if
any
(
start_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to start."
,
*
start_pending
)
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
sync_input_socket
:
# One of the local core processes exited.
finished
=
proc_manager
.
finished_procs
(
)
if
proc_manager
else
{}
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above. "
f
"Failed core proc(s):
{
finished
}
"
)
# Receive HELLO and READY messages from the input socket.
eng_identity
,
ready_msg_bytes
=
sync_input_socket
.
recv_multipart
()
eng_index
=
int
.
from_bytes
(
eng_identity
,
byteorder
=
"little"
)
engine
=
next
(
(
e
for
e
in
self
.
core_engines
if
e
.
identity
==
eng_identity
),
None
)
if
engine
is
None
:
raise
RuntimeError
(
f
"Message from engine with unexpected data "
f
"parallel rank:
{
eng_index
}
"
)
msg
=
msgspec
.
msgpack
.
decode
(
ready_msg_bytes
)
status
,
local
=
msg
[
"status"
],
msg
[
"local"
]
if
local
!=
engine
.
local
:
raise
RuntimeError
(
f
"
{
status
}
message from "
f
"
{
'local'
if
local
else
'remote'
}
"
f
"engine
{
eng_index
}
, expected it to be "
f
"
{
'local'
if
engine
.
local
else
'remote'
}
"
)
if
status
==
"HELLO"
and
engine
.
state
==
CoreEngineState
.
NEW
:
# Send init message with DP config info.
init_message
=
self
.
encoder
.
encode
({
"output_socket_address"
:
output_address
,
"parallel_config"
:
{
"data_parallel_master_ip"
:
parallel_config
.
data_parallel_master_ip
,
"data_parallel_master_port"
:
parallel_config
.
data_parallel_master_port
,
"data_parallel_size"
:
parallel_config
.
data_parallel_size
,
},
})
sync_input_socket
.
send_multipart
((
eng_identity
,
*
init_message
),
copy
=
False
)
conn_pending
[
0
if
local
else
1
]
-=
1
start_pending
[
0
if
local
else
1
]
+=
1
engine
.
state
=
CoreEngineState
.
CONNECTED
elif
status
==
"READY"
and
(
engine
.
state
==
CoreEngineState
.
CONNECTED
):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
cache_config
=
self
.
vllm_config
.
cache_config
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
or
0
num_gpu_blocks
+=
msg
[
'num_gpu_blocks'
]
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
start_pending
[
0
if
local
else
1
]
-=
1
engine
.
state
=
CoreEngineState
.
READY
else
:
raise
RuntimeError
(
f
"Unexpected
{
status
}
message for "
f
"
{
'local'
if
local
else
'remote'
}
engine "
f
"
{
eng_index
}
in
{
engine
.
state
}
state."
)
logger
.
debug
(
"%s from %s core engine process %s."
,
status
,
with
zmq_socket_ctx
(
handshake_address
,
zmq
.
ROUTER
,
"local"
if
local
else
"remote"
,
eng_index
)
bind
=
True
)
as
handshake_socket
:
# Start local engines.
if
local_engine_count
:
# In server mode, start_index and local_start_index will
# both be 0.
self
.
resources
.
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
handshake_address
=
handshake_address
,
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
start_index
=
start_index
,
local_start_index
=
local_start_index
)
# Wait for engine core process(es) to start.
self
.
_wait_for_engine_startup
(
handshake_socket
,
input_address
,
output_address
)
def
_wait_for_engine_startup
(
self
,
handshake_socket
:
zmq
.
Socket
,
input_address
:
str
,
output_address
:
str
):
addresses
=
EngineZmqAddresses
(
inputs
=
[
input_address
],
outputs
=
[
output_address
],
)
coordinator
=
self
.
resources
.
coordinator
if
coordinator
is
not
None
:
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
coordinator
.
get_engine_socket_addresses
())
wait_for_engine_startup
(
handshake_socket
,
addresses
,
self
.
core_engines
,
self
.
vllm_config
.
parallel_config
,
self
.
vllm_config
.
cache_config
,
self
.
resources
.
local_engine_manager
,
coordinator
.
proc
if
coordinator
else
None
,
)
def
shutdown
(
self
):
def
shutdown
(
self
):
# Terminate background resources.
# Terminate background resources.
...
@@ -605,8 +559,8 @@ class SyncMPClient(MPClient):
...
@@ -605,8 +559,8 @@ class SyncMPClient(MPClient):
try
:
try
:
shutdown_socket
.
bind
(
shutdown_path
)
shutdown_socket
.
bind
(
shutdown_path
)
poller
=
zmq
.
Poller
()
poller
=
zmq
.
Poller
()
poller
.
register
(
shutdown_socket
)
poller
.
register
(
shutdown_socket
,
zmq
.
POLLIN
)
poller
.
register
(
out_socket
)
poller
.
register
(
out_socket
,
zmq
.
POLLIN
)
while
True
:
while
True
:
socks
=
poller
.
poll
()
socks
=
poller
.
poll
()
if
not
socks
:
if
not
socks
:
...
@@ -668,7 +622,7 @@ class SyncMPClient(MPClient):
...
@@ -668,7 +622,7 @@ class SyncMPClient(MPClient):
future
:
Future
[
Any
]
=
Future
()
future
:
Future
[
Any
]
=
Future
()
self
.
utility_results
[
call_id
]
=
future
self
.
utility_results
[
call_id
]
=
future
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
(
call_id
,
method
,
args
))
(
0
,
call_id
,
method
,
args
))
return
future
.
result
()
return
future
.
result
()
...
@@ -730,15 +684,21 @@ class SyncMPClient(MPClient):
...
@@ -730,15 +684,21 @@ class SyncMPClient(MPClient):
class
AsyncMPClient
(
MPClient
):
class
AsyncMPClient
(
MPClient
):
"""Asyncio-compatible client for multi-proc EngineCore."""
"""Asyncio-compatible client for multi-proc EngineCore."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
def
__init__
(
self
,
log_stats
:
bool
):
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
):
super
().
__init__
(
super
().
__init__
(
asyncio_mode
=
True
,
asyncio_mode
=
True
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
log_stats
=
log_stats
,
client_addresses
=
client_addresses
,
)
)
self
.
client_index
=
client_index
self
.
outputs_queue
=
asyncio
.
Queue
[
Union
[
EngineCoreOutputs
,
self
.
outputs_queue
=
asyncio
.
Queue
[
Union
[
EngineCoreOutputs
,
Exception
]]()
Exception
]]()
try
:
try
:
...
@@ -854,12 +814,13 @@ class AsyncMPClient(MPClient):
...
@@ -854,12 +814,13 @@ class AsyncMPClient(MPClient):
future
=
asyncio
.
get_running_loop
().
create_future
()
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
self
.
utility_results
[
call_id
]
=
future
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
(
self
.
client_index
,
call_id
,
method
,
args
)))
await
self
.
_send_input_message
(
message
,
engine
,
args
)
await
self
.
_send_input_message
(
message
,
engine
,
args
)
self
.
_ensure_output_queue_task
()
self
.
_ensure_output_queue_task
()
return
await
future
return
await
future
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
request
.
client_index
=
self
.
client_index
await
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
)
await
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
)
self
.
_ensure_output_queue_task
()
self
.
_ensure_output_queue_task
()
...
@@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
EngineCore."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
def
__init__
(
self
,
log_stats
:
bool
):
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
):
self
.
current_wave
=
0
self
.
current_wave
=
0
self
.
engines_running
=
False
self
.
engines_running
=
False
# To route aborts to the correct engine.
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
client_addresses
,
client_index
)
assert
len
(
self
.
core_engines
)
>
1
assert
len
(
self
.
core_engines
)
>
1
# List of [waiting, running] pair per engine.
self
.
lb_engines
:
list
[
list
[
int
]]
=
[]
self
.
first_req_sock_addr
=
get_open_zmq_inproc_path
()
self
.
first_req_send_socket
=
self
.
resources
.
first_req_send_socket
=
(
make_zmq_socket
(
self
.
ctx
,
self
.
first_req_sock_addr
,
zmq
.
PAIR
,
bind
=
True
))
try
:
# If we are running in an asyncio event loop, start the stats task.
# Otherwise, it will be started lazily.
asyncio
.
get_running_loop
()
self
.
_ensure_stats_update_task
()
except
RuntimeError
:
pass
def
_ensure_stats_update_task
(
self
):
resources
=
self
.
resources
if
resources
.
stats_update_task
is
not
None
:
return
assert
self
.
stats_update_address
is
not
None
async
def
run_engine_stats_update_task
():
with
make_zmq_socket
(
self
.
ctx
,
self
.
stats_update_address
,
zmq
.
XSUB
)
as
socket
,
make_zmq_socket
(
self
.
ctx
,
self
.
first_req_sock_addr
,
zmq
.
PAIR
,
bind
=
False
)
as
first_req_rcv_socket
:
# Send subscription message.
await
socket
.
send
(
b
'
\x01
'
)
poller
=
zmq
.
asyncio
.
Poller
()
poller
.
register
(
socket
,
zmq
.
POLLIN
)
poller
.
register
(
first_req_rcv_socket
,
zmq
.
POLLIN
)
while
True
:
events
=
await
poller
.
poll
()
if
not
self
.
engines_running
and
len
(
events
)
==
2
or
(
events
[
0
][
0
]
==
first_req_rcv_socket
):
# Send a message to notify the coordinator that
# we're sending a request while the engines are
# paused, so that it can wake the others up
# (to run dummy EP loop).
self
.
engines_running
=
True
buf
=
first_req_rcv_socket
.
recv
(
flags
=
zmq
.
NOBLOCK
).
result
()
target_eng_index
=
int
.
from_bytes
(
buf
,
"little"
)
msg
=
msgspec
.
msgpack
.
encode
(
(
target_eng_index
,
self
.
current_wave
))
await
socket
.
send
(
msg
)
buf
=
None
while
True
:
# Drain all stats events (we only care about latest).
future
:
asyncio
.
Future
[
bytes
]
=
socket
.
recv
(
flags
=
zmq
.
NOBLOCK
)
if
isinstance
(
future
.
exception
(),
zmq
.
Again
):
break
buf
=
future
.
result
()
if
buf
is
None
:
continue
# Update local load-balancing state.
counts
,
wave
,
running
=
msgspec
.
msgpack
.
decode
(
buf
)
self
.
current_wave
=
wave
self
.
engines_running
=
running
self
.
lb_engines
=
counts
resources
.
stats_update_task
=
asyncio
.
create_task
(
run_engine_stats_update_task
())
def
get_core_engine_for_request
(
self
)
->
CoreEngine
:
if
not
self
.
lb_engines
:
return
self
.
core_engines
[
0
]
# TODO use P2C alg for larger DP sizes
num_engines
=
len
(
self
.
lb_engines
)
min_counts
=
[
sys
.
maxsize
,
sys
.
maxsize
]
eng_index
=
0
for
i
in
range
(
num_engines
):
# Start from client_index to help with balancing when engines
# are empty.
idx
=
(
self
.
client_index
+
i
)
%
num_engines
counts
=
self
.
lb_engines
[
idx
]
if
counts
<
min_counts
:
min_counts
=
counts
eng_index
=
idx
# Adjust local counts for better balancing between stats updates
# from the coordinator (which happen every 100ms).
if
min_counts
[
0
]:
min_counts
[
0
]
+=
1
else
:
min_counts
[
1
]
+=
1
return
self
.
core_engines
[
eng_index
]
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
# Only the result from the first engine is returned.
# Only the result from the first engine is returned.
return
(
await
asyncio
.
gather
(
*
[
return
(
await
asyncio
.
gather
(
*
[
...
@@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient):
]))[
0
]
]))[
0
]
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
self
.
_ensure_stats_update_task
()
request
.
current_wave
=
self
.
current_wave
request
.
current_wave
=
self
.
current_wave
request
.
client_index
=
self
.
client_index
chosen_engine
=
self
.
get_core_engine_for_request
()
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
chosen_engine
.
num_reqs_in_flight
+=
1
to_await
=
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
,
to_await
=
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
,
chosen_engine
)
chosen_engine
)
if
not
self
.
engines_running
:
if
not
self
.
engines_running
:
# Send request to chosen engine and dp start loop
# Notify coordinator that we're sending a request
# control message to all other engines.
await
self
.
first_req_send_socket
.
send
(
chosen_engine
.
identity
)
self
.
engines_running
=
True
to_await
=
asyncio
.
gather
(
to_await
,
# type: ignore[assignment]
*
self
.
_start_wave_coros
(
exclude_index
=
chosen_engine
.
index
))
await
to_await
await
to_await
self
.
_ensure_output_queue_task
()
self
.
_ensure_output_queue_task
()
def
get_core_engine_for_request
(
self
)
->
CoreEngine
:
return
min
(
self
.
core_engines
,
key
=
lambda
e
:
e
.
num_reqs_in_flight
)
@
staticmethod
@
staticmethod
async
def
process_engine_outputs
(
self
:
"DPAsyncMPClient"
,
async
def
process_engine_outputs
(
self
:
"DPAsyncMPClient"
,
outputs
:
EngineCoreOutputs
):
outputs
:
EngineCoreOutputs
):
if
self
.
reqs_in_flight
:
if
outputs
.
finished_requests
and
self
.
reqs_in_flight
:
for
req_id
in
outputs
.
finished_requests
or
():
for
req_id
in
outputs
.
finished_requests
:
if
engine
:
=
self
.
reqs_in_flight
.
pop
(
req_id
,
None
):
self
.
reqs_in_flight
.
pop
(
req_id
,
None
)
engine
.
num_reqs_in_flight
-=
1
if
outputs
.
wave_complete
is
not
None
:
# Current wave is complete, move to next wave number
# and mark engines as paused.
if
self
.
current_wave
<=
outputs
.
wave_complete
:
self
.
current_wave
=
outputs
.
wave_complete
+
1
self
.
engines_running
=
False
elif
outputs
.
start_wave
is
not
None
and
(
outputs
.
start_wave
>
self
.
current_wave
or
(
outputs
.
start_wave
==
self
.
current_wave
and
not
self
.
engines_running
)):
# Engine received request for a non-current wave so we must ensure
# that other engines progress to the next wave.
self
.
current_wave
=
outputs
.
start_wave
self
.
engines_running
=
True
await
asyncio
.
gather
(
*
self
.
_start_wave_coros
(
exclude_index
=
outputs
.
engine_index
))
def
_start_wave_coros
(
self
,
exclude_index
:
int
)
->
list
[
Awaitable
[
None
]]:
logger
.
debug
(
"Sending start DP wave %d."
,
self
.
current_wave
)
return
[
self
.
_send_input
(
EngineCoreRequestType
.
START_DP_WAVE
,
self
.
current_wave
,
engine
)
for
engine
in
self
.
core_engines
if
engine
.
index
!=
exclude_index
]
async
def
abort_requests_async
(
self
,
request_ids
:
list
[
str
])
->
None
:
async
def
abort_requests_async
(
self
,
request_ids
:
list
[
str
])
->
None
:
if
not
request_ids
:
if
not
request_ids
:
...
...
vllm/v1/metrics/loggers.py
View file @
2dbe8c07
...
@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
...
@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingLogging
,
SpecDecodingProm
from
vllm.v1.spec_decode.metrics
import
SpecDecodingLogging
,
SpecDecodingProm
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5.0
StatLoggerFactory
=
Callable
[[
VllmConfig
,
int
],
"StatLoggerBase"
]
StatLoggerFactory
=
Callable
[[
VllmConfig
,
int
],
"StatLoggerBase"
]
...
@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
...
@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
...
...
@
abstractmethod
@
abstractmethod
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
iteration_stats
:
Optional
[
IterationStats
]):
...
...
...
@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats
# Compute summary metrics for tracked stats
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
iteration_stats
:
Optional
[
IterationStats
]):
"""Log Stats to standard output."""
"""Log Stats to standard output."""
if
iteration_stats
:
if
iteration_stats
:
self
.
_track_iteration_stats
(
iteration_stats
)
self
.
_track_iteration_stats
(
iteration_stats
)
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
is
not
None
:
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
scheduler_stats
.
spec_decoding_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
def
log
(
self
):
now
=
time
.
monotonic
()
now
=
time
.
monotonic
()
...
@@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
def
log_engine_initialized
(
self
):
logger
.
info
(
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
"vllm cache_config_info with initialization "
\
logger
.
info
(
"after num_gpu_blocks is: %d"
,
"Engine %03d: vllm cache_config_info with initialization "
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
"after num_gpu_blocks is: %d"
,
self
.
engine_index
,
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
class
PrometheusStatLogger
(
StatLoggerBase
):
class
PrometheusStatLogger
(
StatLoggerBase
):
...
@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls
=
SpecDecodingProm
_spec_decoding_cls
=
SpecDecodingProm
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
self
.
_unregister_vllm_metrics
()
unregister_vllm_metrics
()
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
engine_index
=
engine_index
self
.
engine_index
=
engine_index
# Use this flag to hide metrics that were deprecated in
# Use this flag to hide metrics that were deprecated in
...
@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_running"
,
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests in model execution batches."
,
documentation
=
"Number of requests in model execution batches."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
documentation
=
"Number of requests waiting to be processed."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
#
#
...
@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_gpu_prefix_cache_queries
=
self
.
_counter_cls
(
self
.
counter_gpu_prefix_cache_queries
=
self
.
_counter_cls
(
...
@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self
.
histogram_iteration_tokens
=
\
self
.
histogram_iteration_tokens
=
\
self
.
_histogram_cls
(
self
.
_histogram_cls
(
name
=
"vllm:iteration_tokens_total"
,
name
=
"vllm:iteration_tokens_total"
,
...
@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
#
#
# LoRA metrics
# LoRA metrics
#
#
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
self
.
gauge_lora_info
:
Optional
[
prometheus_client
.
Gauge
]
=
None
self
.
gauge_lora_info
:
Optional
[
prometheus_client
.
Gauge
]
=
None
if
vllm_config
.
lora_config
is
not
None
:
if
vllm_config
.
lora_config
is
not
None
:
self
.
labelname_max_lora
=
"max_lora"
self
.
labelname_max_lora
=
"max_lora"
...
@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
_gauge_cls
(
self
.
_gauge_cls
(
name
=
"vllm:lora_requests_info"
,
name
=
"vllm:lora_requests_info"
,
documentation
=
"Running stats on lora requests."
,
documentation
=
"Running stats on lora requests."
,
multiprocess_mode
=
"sum"
,
labelnames
=
[
labelnames
=
[
self
.
labelname_max_lora
,
self
.
labelname_max_lora
,
self
.
labelname_waiting_lora_adapters
,
self
.
labelname_waiting_lora_adapters
,
self
.
labelname_running_lora_adapters
,
self
.
labelname_running_lora_adapters
,
])
],
)
def
log_metrics_info
(
self
,
type
:
str
,
config_obj
:
SupportsMetricsInfo
):
def
log_metrics_info
(
self
,
type
:
str
,
config_obj
:
SupportsMetricsInfo
):
metrics_info
=
config_obj
.
metrics_info
()
metrics_info
=
config_obj
.
metrics_info
()
metrics_info
[
"engine"
]
=
self
.
engine_index
metrics_info
[
"engine"
]
=
self
.
engine_index
...
@@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge
=
self
.
_gauge_cls
(
info_gauge
=
self
.
_gauge_cls
(
name
=
name
,
name
=
name
,
documentation
=
documentation
,
documentation
=
documentation
,
labelnames
=
metrics_info
.
keys
()).
labels
(
**
metrics_info
)
multiprocess_mode
=
"mostrecent"
,
labelnames
=
metrics_info
.
keys
(),
).
labels
(
**
metrics_info
)
info_gauge
.
set
(
1
)
info_gauge
.
set
(
1
)
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
iteration_stats
:
Optional
[
IterationStats
]):
"""Log to prometheus."""
"""Log to prometheus."""
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
if
scheduler_stats
is
not
None
:
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
self
.
gauge_gpu_cache_usage
.
set
(
scheduler_stats
.
gpu_cache_usage
)
self
.
gauge_gpu_cache_usage
.
set
(
scheduler_stats
.
gpu_cache_usage
)
self
.
counter_gpu_prefix_cache_queries
.
inc
(
self
.
counter_gpu_prefix_cache_queries
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
queries
)
scheduler_stats
.
prefix_cache_stats
.
queries
)
self
.
counter_gpu_prefix_cache_hits
.
inc
(
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
scheduler_stats
.
prefix_cache_stats
.
hits
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_prom
.
observe
(
self
.
spec_decoding_prom
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
scheduler_stats
.
spec_decoding_stats
)
if
iteration_stats
is
None
:
if
iteration_stats
is
None
:
return
return
...
@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_lora_info
.
labels
(
**
lora_info_labels
)
\
self
.
gauge_lora_info
.
labels
(
**
lora_info_labels
)
\
.
set_to_current_time
()
.
set_to_current_time
()
@
staticmethod
def
_unregister_vllm_metrics
():
# Unregister any existing vLLM collectors (for CI/CD
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
def
log_engine_initialized
(
self
):
def
log_engine_initialized
(
self
):
self
.
log_metrics_info
(
"cache_config"
,
self
.
vllm_config
.
cache_config
)
self
.
log_metrics_info
(
"cache_config"
,
self
.
vllm_config
.
cache_config
)
...
...
vllm/v1/metrics/prometheus.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
os
import
tempfile
from
typing
import
Optional
from
prometheus_client
import
REGISTRY
,
CollectorRegistry
,
multiprocess
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir
:
Optional
[
tempfile
.
TemporaryDirectory
]
=
None
def setup_multiprocess_prometheus():
    """Ensure ``PROMETHEUS_MULTIPROC_DIR`` is set for this process.

    If the variable is already present in the environment it is left
    untouched and a warning is logged. Otherwise a new
    ``tempfile.TemporaryDirectory`` is created, stored in the module-level
    ``_prometheus_multiproc_dir`` global, and its path is exported through
    the environment variable.
    """
    global _prometheus_multiproc_dir

    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
        return

    # Make TemporaryDirectory for prometheus multiprocessing.
    # Note: the global TemporaryDirectory will be automatically
    # cleaned up upon exit.
    multiproc_dir = tempfile.TemporaryDirectory()
    _prometheus_multiproc_dir = multiproc_dir
    os.environ["PROMETHEUS_MULTIPROC_DIR"] = multiproc_dir.name
    logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
                 multiproc_dir.name)
def get_prometheus_registry():
    """Return the prometheus registry matching the multiprocessing setup.

    Returns:
        The global ``REGISTRY`` when ``PROMETHEUS_MULTIPROC_DIR`` is not
        set in the environment; otherwise a new ``CollectorRegistry``
        that has been passed to a ``MultiProcessCollector``.
    """
    if os.getenv("PROMETHEUS_MULTIPROC_DIR") is None:
        return REGISTRY

    logger.debug("Using multiprocess registry for prometheus metrics")
    multiproc_registry = CollectorRegistry()
    # The collector receives the registry at construction time; the
    # collector object itself is not needed afterwards.
    multiprocess.MultiProcessCollector(multiproc_registry)
    return multiproc_registry
def unregister_vllm_metrics():
    """Unregister existing vLLM collectors from the global registry.

    Useful for testing and CI/CD, where metrics may be registered several
    times across test runs. In the multiprocess case the metrics also
    need to be removed from the global registry.

    A collector is treated as a vLLM collector when it has a ``_name``
    attribute containing the substring ``"vllm"``.
    """
    # Iterate over a snapshot (``list(...)``) of the registry's collectors,
    # since matching entries are unregistered during the walk.
    for candidate in list(REGISTRY._collector_to_names):
        if "vllm" in getattr(candidate, "_name", ""):
            REGISTRY.unregister(candidate)
def shutdown_prometheus():
    """Mark the current process as dead for prometheus multiprocess mode.

    Any exception raised while doing so is logged as an error and
    suppressed, so this function does not propagate cleanup failures.
    """
    current_pid = os.getpid()
    try:
        multiprocess.mark_process_dead(current_pid)
        logger.debug("Marked Prometheus metrics for process %d as dead",
                     current_pid)
    except Exception as exc:
        logger.error("Error during metrics cleanup: %s", str(exc))
vllm/v1/request.py
View file @
2dbe8c07
...
@@ -26,12 +26,13 @@ class Request:
...
@@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
client_index
:
int
=
0
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
client_index
=
client_index
self
.
sampling_params
=
sampling_params
self
.
sampling_params
=
sampling_params
# Because of LoRA, the eos token id can be different for each request.
# Because of LoRA, the eos token id can be different for each request.
self
.
eos_token_id
=
eos_token_id
self
.
eos_token_id
=
eos_token_id
...
@@ -90,13 +91,13 @@ class Request:
...
@@ -90,13 +91,13 @@ class Request:
return
cls
(
return
cls
(
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
client_index
=
request
.
client_index
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt_token_ids
=
request
.
prompt_token_ids
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
sampling_params
=
request
.
sampling_params
,
sampling_params
=
request
.
sampling_params
,
eos_token_id
=
request
.
eos_token_id
,
eos_token_id
=
request
.
eos_token_id
,
arrival_time
=
request
.
arrival_time
,
lora_request
=
request
.
lora_request
,
lora_request
=
request
.
lora_request
,
structured_output_request
=
StructuredOutputRequest
(
structured_output_request
=
StructuredOutputRequest
(
sampling_params
=
request
.
sampling_params
),
sampling_params
=
request
.
sampling_params
),
...
...
vllm/v1/utils.py
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
argparse
import
multiprocessing
import
time
import
time
import
weakref
import
weakref
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
multiprocessing
import
Process
,
connection
from
multiprocessing
import
Process
,
connection
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
from
multiprocessing.process
import
BaseProcess
overload
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
import
msgspec
import
torch
import
torch
import
zmq
from
vllm.config
import
VllmConfig
from
vllm.config
import
CacheConfig
,
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
get_mp_context
,
kill_process_tree
from
vllm.utils
import
(
get_mp_context
,
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
kill_process_tree
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
STARTUP_POLL_PERIOD_MS
=
10000
class
ConstantList
(
Generic
[
T
],
Sequence
):
class
ConstantList
(
Generic
[
T
],
Sequence
):
...
@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
...
@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return
f
"ConstantList(
{
self
.
_x
}
)"
return
f
"ConstantList(
{
self
.
_x
}
)"
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Return a ZMQ address for engine/client communication.

    Args:
        local_only: When true, the result of ``get_open_zmq_ipc_path()``
            is returned and ``host``/``port`` are ignored.
        host: Host passed to ``get_tcp_uri`` when ``local_only`` is false.
        port: Port passed to ``get_tcp_uri``. A falsy value (the default
            ``0``) means a port is obtained from ``get_open_port()``.

    Returns:
        The address string produced by the selected helper.
    """
    if local_only:
        return get_open_zmq_ipc_path()

    if not port:
        port = get_open_port()
    return get_tcp_uri(host, port)
class APIServerProcessManager:
    """Manages a group of API server processes.

    Handles creation and termination of API server worker processes.
    The processes are exposed through ``self.processes`` so that callers
    (e.g. ``wait_for_completion_or_failure``) can watch them.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        self.processes: list[BaseProcess] = []

        # Shutdown only the API server processes on garbage collection.
        # The extra processes are managed by their owners.
        # The finalizer is registered *before* any process is started so
        # that servers which did start are still shut down if a later
        # ``proc.start()`` raises. It holds the same list object that is
        # appended to below.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        # NOTE: zip() stops at the shortest input, so fewer than
        # ``num_servers`` processes are started if either address list
        # is shorter.
        for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
                                        output_addresses):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_index": i
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            proc = spawn_context.Process(target=target_server_fn,
                                         name=f"ApiServer_{i}",
                                         args=(listen_address, sock, args,
                                               client_config))
            self.processes.append(proc)
            proc.start()

        logger.info("Started %d API server processes", len(self.processes))

    def close(self) -> None:
        """Run the finalizer, shutting down the API server processes."""
        self._finalizer()
class
CoreEngineProcManager
:
class
CoreEngineProcManager
:
"""
"""
Utility class to handle creation, readiness, and shutdown
Utility class to handle creation, readiness, and shutdown
...
@@ -109,7 +191,7 @@ class CoreEngineProcManager:
...
@@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index
:
int
,
local_start_index
:
int
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
):
):
...
@@ -117,12 +199,12 @@ class CoreEngineProcManager:
...
@@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs
=
{
common_kwargs
=
{
"vllm_config"
:
vllm_config
,
"vllm_config"
:
vllm_config
,
"on_head_node"
:
on_head_node
,
"on_head_node"
:
on_head_node
,
"
input
_address"
:
input
_address
,
"
handshake
_address"
:
handshake
_address
,
"executor_class"
:
executor_class
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
"log_stats"
:
log_stats
,
}
}
self
.
processes
:
list
[
Process
]
=
[]
self
.
processes
:
list
[
Base
Process
]
=
[]
for
index
in
range
(
local_engine_count
):
for
index
in
range
(
local_engine_count
):
local_index
=
local_start_index
+
index
local_index
=
local_start_index
+
index
global_index
=
start_index
+
index
global_index
=
start_index
+
index
...
@@ -135,8 +217,7 @@ class CoreEngineProcManager:
...
@@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank"
:
local_index
,
"local_dp_rank"
:
local_index
,
}))
}))
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
,
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
input_address
)
try
:
try
:
for
proc
in
self
.
processes
:
for
proc
in
self
.
processes
:
proc
.
start
()
proc
.
start
()
...
@@ -164,9 +245,199 @@ class CoreEngineProcManager:
...
@@ -164,9 +245,199 @@ class CoreEngineProcManager:
}
}
class CoreEngineState(Enum):
    """Progress of a core engine through the startup handshake."""

    # Initial state assigned when a CoreEngine is constructed.
    NEW = auto()
    # Set by wait_for_engine_startup after a HELLO message is handled.
    CONNECTED = auto()
    # Set by wait_for_engine_startup after a READY message is handled.
    READY = auto()
class CoreEngine:
    """One per data parallel rank."""

    def __init__(self, index: int = 0, local: bool = True):
        """Record the engine's index and locality, starting in state NEW.

        Args:
            index: Index of the engine; also encoded into ``identity``.
            local: Whether the engine is flagged as local.
        """
        # 2-byte little-endian encoding of the index.
        engine_identity = index.to_bytes(2, "little")

        self.local = local
        self.index = index
        self.identity = engine_identity
        self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
    """ZMQ socket addresses of the front-end clients and, if applicable,
    the DP coordinator."""

    # Request side: one ZMQ input socket address per front-end client.
    inputs: list[str]
    # Response side: one ZMQ output socket address per front-end client.
    outputs: list[str]
    # DP coordinator ZMQ input socket address, if applicable.
    coordinator_input: Optional[str] = None
    # DP coordinator ZMQ output socket address, if applicable.
    coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
    """Payload sent to each engine process during the startup handshake.

    Carries the addresses of the front-end ZMQ queues the engine should
    connect to, together with parallel configuration values.
    """
    # Front-end (and optional coordinator) ZMQ addresses.
    addresses: EngineZmqAddresses
    # Parallel config entries keyed by name, with int or str values.
    parallel_config: dict[str, Union[int, str]]
def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    cache_config: CacheConfig,
    proc_manager: Optional[CoreEngineProcManager],
    coord_process: Optional[Process],
):
    """Block until every core engine has completed the startup handshake.

    Each engine is expected to send a HELLO message (answered with an
    ``EngineHandshakeMetadata`` init message) followed by a READY message.
    On READY, the engine's reported ``num_gpu_blocks`` is added to
    ``cache_config.num_gpu_blocks``.

    Args:
        handshake_socket: ZMQ socket the handshake messages arrive on and
            the init messages are sent from.
        addresses: Addresses forwarded to each engine in the init message.
        core_engines: Engines expected to connect; their ``state`` is
            advanced in place.
        parallel_config: Supplies the local engine count and the data
            parallel settings sent to the engines.
        cache_config: Updated in place with the summed ``num_gpu_blocks``.
        proc_manager: Optional manager whose process sentinels are polled.
        coord_process: Optional coordinator process whose sentinel is
            polled.

    Raises:
        RuntimeError: If a polled event other than a single
            handshake-socket event occurs, if a message comes from an
            unknown engine identity, if the local/remote flag does not
            match the engine, or if a message is unexpected for the
            engine's current state.
    """
    num_local = parallel_config.data_parallel_size_local
    num_remote = len(core_engines) - num_local

    # Both counters are [local, remote] pairs.
    conn_pending = [num_local, num_remote]
    start_pending = [0, 0]

    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)
    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)

    while any(conn_pending) or any(start_pending):
        events = poller.poll(STARTUP_POLL_PERIOD_MS)

        if not events:
            # Poll timed out: report what is still outstanding and retry.
            if any(conn_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to connect.", *conn_pending)
            if any(start_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to start.", *start_pending)
            continue

        if len(events) > 1 or events[0][0] != handshake_socket:
            # One of the local core processes exited.
            finished = proc_manager.finished_procs() if proc_manager else {}
            coord_exited = (coord_process is not None
                            and coord_process.exitcode is not None)
            if coord_exited:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above. "
                               f"Failed core proc(s): {finished}")

        # Receive HELLO and READY messages from the handshake socket.
        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
        eng_index = int.from_bytes(eng_identity, "little")

        # Find the first engine whose identity matches the sender.
        engine = None
        for candidate in core_engines:
            if candidate.identity == eng_identity:
                engine = candidate
                break
        if engine is None:
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")

        msg = msgspec.msgpack.decode(ready_msg_bytes)
        status = msg["status"]
        local = msg["local"]
        origin = "local" if local else "remote"

        if local != engine.local:
            expected_origin = "local" if engine.local else "remote"
            raise RuntimeError(f"{status} message from "
                               f"{origin} "
                               f"engine {eng_index}, expected it to be "
                               f"{expected_origin}")

        # Index into the [local, remote] counters.
        slot = 0 if local else 1

        if status == "HELLO" and engine.state == CoreEngineState.NEW:
            # Send init message with DP config info.
            dp_settings = {
                "data_parallel_master_ip":
                parallel_config.data_parallel_master_ip,
                "data_parallel_master_port":
                parallel_config.data_parallel_master_port,
                "data_parallel_size":
                parallel_config.data_parallel_size,
            }
            metadata = EngineHandshakeMetadata(addresses=addresses,
                                               parallel_config=dp_settings)
            init_message = msgspec.msgpack.encode(metadata)
            handshake_socket.send_multipart((eng_identity, init_message),
                                            copy=False)
            conn_pending[slot] -= 1
            start_pending[slot] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            blocks_so_far = cache_config.num_gpu_blocks or 0
            cache_config.num_gpu_blocks = (blocks_so_far +
                                           msg["num_gpu_blocks"])

            start_pending[slot] -= 1
            engine.state = CoreEngineState.READY
        else:
            raise RuntimeError(f"Unexpected {status} message for "
                               f"{origin} engine "
                               f"{eng_index} in {engine.state} state.")

        logger.debug("%s from %s core engine process %s.", status, origin,
                     eng_index)
def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        local_engine_manager: Optional[CoreEngineProcManager] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Watches the API server processes plus, when provided, the coordinator
    process and the local engine processes. Returns once all of them have
    exited with status 0, or when a KeyboardInterrupt is received (which
    is logged and not re-raised). In every case the managers and the
    coordinator are closed before returning or raising.

    Args:
        api_server_manager: Manager of the API server processes.
        local_engine_manager: Optional manager of local engine processes.
        coordinator: Optional DP coordinator, watched through its ``proc``.

    Raises:
        RuntimeError: If any watched process exits with a non-zero status.
    """
    try:
        logger.info("Waiting for API servers to complete ...")

        # Gather every process to watch, then index them by sentinel
        # for efficient lookup.
        watched: list[BaseProcess] = list(api_server_manager.processes)
        if coordinator:
            watched.append(coordinator.proc)
        if local_engine_manager:
            watched.extend(local_engine_manager.processes)
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in watched
        }

        # Keep waiting until every watched process has terminated.
        while sentinel_to_proc:
            # Blocks until at least one sentinel is ready.
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)

            for sentinel in ready_sentinels:
                exited = sentinel_to_proc.pop(sentinel)
                if exited.exitcode == 0:
                    continue
                # Process exited with an error.
                raise RuntimeError(
                    f"Process {exited.name} (PID: {exited.pid}) "
                    f"died with exit code {exited.exitcode}")
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as exc:
        logger.exception("Exception occurred while running API servers: %s",
                         str(exc))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if local_engine_manager:
            local_engine_manager.close()
# Note(rob): shutdown function cannot be a bound method,
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the obje
decoup
ct.
# else the gc cannot collect the object.
def
shutdown
(
procs
:
list
[
Process
]
,
input_address
:
str
):
def
shutdown
(
procs
:
list
[
Base
Process
]):
# Shutdown the process.
# Shutdown the process.
for
proc
in
procs
:
for
proc
in
procs
:
if
proc
.
is_alive
():
if
proc
.
is_alive
():
...
@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
...
@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
kill_process_tree
(
pid
)
kill_process_tree
(
pid
)
# Remove zmq ipc socket files.
if
input_address
.
startswith
(
"ipc://"
):
socket_file
=
input_address
[
len
(
"ipc://"
):]
if
os
and
os
.
path
.
exists
(
socket_file
):
os
.
remove
(
socket_file
)
def
bind_kv_cache
(
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment