Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dbe8c07
Unverified
Commit
2dbe8c07
authored
May 30, 2025
by
Nick Hill
Committed by
GitHub
May 30, 2025
Browse files
[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
parent
84ec470f
Changes
26
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
818 additions
and
353 deletions
+818
-353
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+166
-87
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+244
-212
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+45
-34
vllm/v1/metrics/prometheus.py
vllm/v1/metrics/prometheus.py
+77
-0
vllm/v1/request.py
vllm/v1/request.py
+3
-2
vllm/v1/utils.py
vllm/v1/utils.py
+283
-18
No files found.
vllm/v1/engine/core.py
View file @
2dbe8c07
This diff is collapsed.
Click to expand it.
vllm/v1/engine/core_client.py
View file @
2dbe8c07
This diff is collapsed.
Click to expand it.
vllm/v1/metrics/loggers.py
View file @
2dbe8c07
...
...
@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingLogging
,
SpecDecodingProm
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5.0
StatLoggerFactory
=
Callable
[[
VllmConfig
,
int
],
"StatLoggerBase"
]
...
...
@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
...
@
abstractmethod
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
...
...
...
@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
"""Log Stats to standard output."""
if
iteration_stats
:
self
.
_track_iteration_stats
(
iteration_stats
)
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
is
not
None
:
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
now
=
time
.
monotonic
()
...
...
@@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
logger
.
info
(
"vllm cache_config_info with initialization "
\
"after num_gpu_blocks is: %d"
,
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
logger
.
info
(
"Engine %03d: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d"
,
self
.
engine_index
,
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
class
PrometheusStatLogger
(
StatLoggerBase
):
...
...
@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls
=
SpecDecodingProm
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
self
.
_unregister_vllm_metrics
()
unregister_vllm_metrics
()
self
.
vllm_config
=
vllm_config
self
.
engine_index
=
engine_index
# Use this flag to hide metrics that were deprecated in
...
...
@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests in model execution batches."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
#
...
...
@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_gpu_prefix_cache_queries
=
self
.
_counter_cls
(
...
...
@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets
=
build_1_2_5_buckets
(
max_model_len
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self
.
histogram_iteration_tokens
=
\
self
.
_histogram_cls
(
name
=
"vllm:iteration_tokens_total"
,
...
...
@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
#
# LoRA metrics
#
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
self
.
gauge_lora_info
:
Optional
[
prometheus_client
.
Gauge
]
=
None
if
vllm_config
.
lora_config
is
not
None
:
self
.
labelname_max_lora
=
"max_lora"
...
...
@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
_gauge_cls
(
name
=
"vllm:lora_requests_info"
,
documentation
=
"Running stats on lora requests."
,
multiprocess_mode
=
"sum"
,
labelnames
=
[
self
.
labelname_max_lora
,
self
.
labelname_waiting_lora_adapters
,
self
.
labelname_running_lora_adapters
,
])
],
)
def
log_metrics_info
(
self
,
type
:
str
,
config_obj
:
SupportsMetricsInfo
):
metrics_info
=
config_obj
.
metrics_info
()
metrics_info
[
"engine"
]
=
self
.
engine_index
...
...
@@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge
=
self
.
_gauge_cls
(
name
=
name
,
documentation
=
documentation
,
labelnames
=
metrics_info
.
keys
()).
labels
(
**
metrics_info
)
multiprocess_mode
=
"mostrecent"
,
labelnames
=
metrics_info
.
keys
(),
).
labels
(
**
metrics_info
)
info_gauge
.
set
(
1
)
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
"""Log to prometheus."""
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
if
scheduler_stats
is
not
None
:
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
self
.
gauge_gpu_cache_usage
.
set
(
scheduler_stats
.
gpu_cache_usage
)
self
.
gauge_gpu_cache_usage
.
set
(
scheduler_stats
.
gpu_cache_usage
)
self
.
counter_gpu_prefix_cache_queries
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
queries
)
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
self
.
counter_gpu_prefix_cache_queries
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
queries
)
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_prom
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_prom
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
iteration_stats
is
None
:
return
...
...
@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_lora_info
.
labels
(
**
lora_info_labels
)
\
.
set_to_current_time
()
@
staticmethod
def
_unregister_vllm_metrics
():
# Unregister any existing vLLM collectors (for CI/CD
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
def
log_engine_initialized
(
self
):
self
.
log_metrics_info
(
"cache_config"
,
self
.
vllm_config
.
cache_config
)
...
...
vllm/v1/metrics/prometheus.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
os
import
tempfile
from
typing
import
Optional
from
prometheus_client
import
REGISTRY
,
CollectorRegistry
,
multiprocess
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir
:
Optional
[
tempfile
.
TemporaryDirectory
]
=
None
def setup_multiprocess_prometheus():
    """Make sure ``PROMETHEUS_MULTIPROC_DIR`` is available in the environment.

    If the variable is already present nothing is created and a warning is
    logged, because a user-provided directory is not cleaned up here. If it
    is absent, a ``tempfile.TemporaryDirectory`` is created, stored in the
    module-level ``_prometheus_multiproc_dir`` and its path is exported via
    ``os.environ``.
    """
    global _prometheus_multiproc_dir

    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                       "This directory must be wiped between vLLM runs or "
                       "you will find inaccurate metrics. Unset the variable "
                       "and vLLM will properly handle cleanup.")
        return

    # The TemporaryDirectory object is kept in a module global so that it
    # stays alive for the process lifetime; it is intended to be cleaned up
    # automatically upon exit.
    multiproc_dir = tempfile.TemporaryDirectory()
    _prometheus_multiproc_dir = multiproc_dir
    os.environ["PROMETHEUS_MULTIPROC_DIR"] = multiproc_dir.name
    logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
                 multiproc_dir.name)
def get_prometheus_registry():
    """Pick the prometheus registry matching the multiprocess configuration.

    Returns:
        The global ``REGISTRY`` when ``PROMETHEUS_MULTIPROC_DIR`` is not set;
        otherwise a new ``CollectorRegistry`` that has been handed to
        ``multiprocess.MultiProcessCollector``.
    """
    if os.environ.get("PROMETHEUS_MULTIPROC_DIR") is None:
        # Single-process mode: the default global registry is used.
        return REGISTRY

    logger.debug("Using multiprocess registry for prometheus metrics")
    mp_registry = CollectorRegistry()
    # Constructed for its side effect on ``mp_registry``; the collector
    # object itself is not kept.
    multiprocess.MultiProcessCollector(mp_registry)
    return mp_registry
def unregister_vllm_metrics():
    """Remove previously registered vLLM collectors from the global registry.

    Intended for tests and CI/CD, where metrics may get registered several
    times across runs, and for the multiprocess case, where the metrics need
    to be removed from the global ``REGISTRY``.

    A collector is considered a vLLM collector when it has a ``_name``
    attribute containing the substring ``"vllm"``.
    """
    # Snapshot the matching collectors first so the registry's internal
    # mapping is not being iterated while entries are unregistered.
    vllm_collectors = [
        candidate for candidate in list(REGISTRY._collector_to_names)
        if hasattr(candidate, "_name") and "vllm" in candidate._name
    ]
    for stale in vllm_collectors:
        REGISTRY.unregister(stale)
def shutdown_prometheus():
    """Mark the current process as dead for prometheus multiprocess metrics.

    Best-effort cleanup: any exception raised along the way is logged at
    error level and not propagated to the caller.
    """
    try:
        current_pid = os.getpid()
        multiprocess.mark_process_dead(current_pid)
        logger.debug("Marked Prometheus metrics for process %d as dead",
                     current_pid)
    except Exception as err:
        # Deliberately swallowed so that shutdown can continue.
        logger.error("Error during metrics cleanup: %s", str(err))
vllm/v1/request.py
View file @
2dbe8c07
...
...
@@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
client_index
:
int
=
0
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
None
:
self
.
request_id
=
request_id
self
.
client_index
=
client_index
self
.
sampling_params
=
sampling_params
# Because of LoRA, the eos token id can be different for each request.
self
.
eos_token_id
=
eos_token_id
...
...
@@ -90,13 +91,13 @@ class Request:
return
cls
(
request_id
=
request
.
request_id
,
client_index
=
request
.
client_index
,
prompt_token_ids
=
request
.
prompt_token_ids
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
sampling_params
=
request
.
sampling_params
,
eos_token_id
=
request
.
eos_token_id
,
arrival_time
=
request
.
arrival_time
,
lora_request
=
request
.
lora_request
,
structured_output_request
=
StructuredOutputRequest
(
sampling_params
=
request
.
sampling_params
),
...
...
vllm/v1/utils.py
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
os
import
argparse
import
multiprocessing
import
time
import
weakref
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
multiprocessing
import
Process
,
connection
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
from
multiprocessing.process
import
BaseProcess
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
import
msgspec
import
torch
import
zmq
from
vllm.config
import
VllmConfig
from
vllm.config
import
CacheConfig
,
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
get_mp_context
,
kill_process_tree
from
vllm.utils
import
(
get_mp_context
,
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
kill_process_tree
)
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
logger
=
init_logger
(
__name__
)
T
=
TypeVar
(
"T"
)
STARTUP_POLL_PERIOD_MS
=
10000
class
ConstantList
(
Generic
[
T
],
Sequence
):
...
...
@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return
f
"ConstantList(
{
self
.
_x
}
)"
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Return a ZMQ address for engine/client communication.

    Args:
        local_only: When true, an IPC path from ``get_open_zmq_ipc_path()``
            is returned and ``host``/``port`` are ignored.
        host: Host passed to ``get_tcp_uri`` when ``local_only`` is false.
        port: Port passed to ``get_tcp_uri``; a falsy value (the default
            ``0``) is replaced by the result of ``get_open_port()``.

    Returns:
        The address string produced by the selected helper.
    """
    if local_only:
        return get_open_zmq_ipc_path()

    tcp_port = port or get_open_port()
    return get_tcp_uri(host, tcp_port)
class APIServerProcessManager:
    """Creates, starts and terminates a group of API server processes.

    All worker processes are started in ``__init__`` using the ``spawn``
    multiprocessing context. They are shut down through a
    ``weakref.finalize`` callback, which runs either when ``close()`` is
    called or when this manager object is garbage collected.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Start one API server process per (index, input, output) triple.

        Args:
            target_server_fn: Callable run in each API server process; it is
                invoked with ``(listen_address, sock, args, client_config)``.
            listen_address: Address to listen for client connections.
            sock: Socket for client connections.
            args: Command line arguments.
            num_servers: Number of API server processes to start. Because
                the three sequences are zipped together, fewer processes
                are started if either address list is shorter.
            input_addresses: Input address for each API server.
            output_addresses: Output address for each API server.
            stats_update_address: Optional stats update address; only added
                to the per-server client config when it is not ``None``.
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # API servers are always launched with the "spawn" start method.
        mp_ctx = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        per_server = zip(range(num_servers), input_addresses,
                         output_addresses)
        for server_idx, input_addr, output_addr in per_server:
            client_config = {
                "input_address": input_addr,
                "output_address": output_addr,
                "client_index": server_idx,
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            worker = mp_ctx.Process(
                target=target_server_fn,
                name=f"ApiServer_{server_idx}",
                args=(listen_address, sock, args, client_config),
            )
            self.processes.append(worker)
            worker.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Only the API server processes are shut down by this finalizer;
        # any other processes are managed by their owners. ``shutdown`` is a
        # module-level function (not a bound method) so that the finalizer
        # does not keep this object alive.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        """Run the finalizer, shutting down the API server processes."""
        self._finalizer()
class
CoreEngineProcManager
:
"""
Utility class to handle creation, readiness, and shutdown
...
...
@@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index
:
int
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
...
...
@@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs
=
{
"vllm_config"
:
vllm_config
,
"on_head_node"
:
on_head_node
,
"
input
_address"
:
input
_address
,
"
handshake
_address"
:
handshake
_address
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
}
self
.
processes
:
list
[
Process
]
=
[]
self
.
processes
:
list
[
Base
Process
]
=
[]
for
index
in
range
(
local_engine_count
):
local_index
=
local_start_index
+
index
global_index
=
start_index
+
index
...
...
@@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank"
:
local_index
,
}))
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
,
input_address
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
try
:
for
proc
in
self
.
processes
:
proc
.
start
()
...
...
@@ -164,9 +245,199 @@ class CoreEngineProcManager:
}
class CoreEngineState(Enum):
    """Handshake phase of a core engine as tracked by the front-end."""

    # Initial state assigned when a CoreEngine record is created.
    NEW = auto()
    # Set by wait_for_engine_startup once the engine's HELLO was answered.
    CONNECTED = auto()
    # Set by wait_for_engine_startup once the engine's READY was processed.
    READY = auto()
class CoreEngine:
    """Front-end record for a single data parallel rank (one per rank)."""

    def __init__(self, index: int = 0, local: bool = True):
        """Create the record in the ``NEW`` state.

        Args:
            index: Index of the engine; also encoded into ``identity``.
            local: Whether the engine is considered local.
        """
        self.index = index
        self.local = local
        self.state = CoreEngineState.NEW
        # Two-byte little-endian encoding of the index; the startup
        # handshake matches incoming message identities against this value.
        self.identity = index.to_bytes(length=2, byteorder="little")
@dataclass
class EngineZmqAddresses:
    """ZMQ socket addresses shared between front-end clients and engines."""

    # Request side: one ZMQ input socket address per front-end client.
    inputs: list[str]
    # Response side: one ZMQ output socket address per front-end client.
    outputs: list[str]
    # DP coordinator ZMQ input socket address, when a coordinator is used.
    coordinator_input: Optional[str] = None
    # DP coordinator ZMQ output socket address, when a coordinator is used.
    coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
    """Payload sent to each engine process during the startup handshake.

    It carries the addresses of the front-end ZMQ queues the engine should
    connect to, together with a small dictionary of parallel-config values.
    """

    # Front-end ZMQ addresses for the engine to connect to.
    addresses: EngineZmqAddresses
    # Parallel configuration entries keyed by name (int or str values).
    parallel_config: dict[str, Union[int, str]]
def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    cache_config: CacheConfig,
    proc_manager: Optional[CoreEngineProcManager],
    coord_process: Optional[Process],
):
    """Block until every core engine finished the HELLO/READY handshake.

    The handshake socket is polled together with the sentinels of the
    engine processes (from ``proc_manager``) and of ``coord_process``.
    Each engine is expected to send ``HELLO`` first, which is answered with
    a msgpack-encoded ``EngineHandshakeMetadata``, and ``READY`` afterwards,
    at which point its ``num_gpu_blocks`` is added to
    ``cache_config.num_gpu_blocks``.

    Args:
        handshake_socket: ZMQ socket on which engines send their
            ``(identity, message)`` multipart handshake messages.
        addresses: Addresses forwarded to engines in the init message.
        core_engines: Expected engines; their ``state`` is updated in place.
        parallel_config: Supplies ``data_parallel_size_local`` and the
            data-parallel master ip/port/size sent to the engines.
        cache_config: Its ``num_gpu_blocks`` is updated in place.
        proc_manager: Optional manager whose process sentinels are watched.
        coord_process: Optional coordinator process whose sentinel is
            watched.

    Raises:
        RuntimeError: If a poll returns anything other than a single event
            on the handshake socket (treated as a process having exited),
            if a message comes from an unknown engine identity, if the
            message's ``local`` flag disagrees with the engine record, or
            if the status is not valid for the engine's current state.
    """

    # Wait for engine core process(es) to send ready messages.
    local_count = parallel_config.data_parallel_size_local
    remote_count = len(core_engines) - local_count
    # [local, remote] counts
    conn_pending, start_pending = [local_count, remote_count], [0, 0]
    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)

    # Also watch process sentinels so an early exit is noticed while waiting.
    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)
    while any(conn_pending) or any(start_pending):
        events = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not events:
            # Poll timed out: report what is still outstanding and retry.
            if any(conn_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to connect.", *conn_pending)
            if any(start_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to start.", *start_pending)
            continue
        # Anything other than exactly one event on the handshake socket
        # means at least one registered sentinel fired.
        if len(events) > 1 or events[0][0] != handshake_socket:
            # One of the local core processes exited.
            finished = proc_manager.finished_procs() if proc_manager else {}
            if coord_process is not None and coord_process.exitcode is not None:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above. "
                               f"Failed core proc(s): {finished}")

        # Receive HELLO and READY messages from the input socket.
        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
        # Decoded only for use in log and error messages; matching below is
        # done on the raw identity bytes.
        eng_index = int.from_bytes(eng_identity, "little")
        engine = next((e for e in core_engines if e.identity == eng_identity),
                      None)
        if engine is None:
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")
        msg = msgspec.msgpack.decode(ready_msg_bytes)
        status, local = msg["status"], msg["local"]
        if local != engine.local:
            raise RuntimeError(f"{status} message from "
                               f"{'local' if local else 'remote'} "
                               f"engine {eng_index}, expected it to be "
                               f"{'local' if engine.local else 'remote'}")

        if status == "HELLO" and engine.state == CoreEngineState.NEW:

            # Send init message with DP config info.
            init_message = msgspec.msgpack.encode(
                EngineHandshakeMetadata(
                    addresses=addresses,
                    parallel_config={
                        "data_parallel_master_ip":
                        parallel_config.data_parallel_master_ip,
                        "data_parallel_master_port":
                        parallel_config.data_parallel_master_port,
                        "data_parallel_size":
                        parallel_config.data_parallel_size,
                    }))
            handshake_socket.send_multipart((eng_identity, init_message),
                                            copy=False)
            # The engine moves from "waiting to connect" to "waiting to
            # start"; index 0 tracks local engines, index 1 remote ones.
            conn_pending[0 if local else 1] -= 1
            start_pending[0 if local else 1] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            num_gpu_blocks = cache_config.num_gpu_blocks or 0
            num_gpu_blocks += msg["num_gpu_blocks"]
            cache_config.num_gpu_blocks = num_gpu_blocks

            start_pending[0 if local else 1] -= 1
            engine.state = CoreEngineState.READY
        else:
            # Status/state combination that the handshake does not allow.
            raise RuntimeError(f"Unexpected {status} message for "
                               f"{'local' if local else 'remote'} engine "
                               f"{eng_index} in {engine.state} state.")

        logger.debug("%s from %s core engine process %s.", status,
                     "local" if local else "remote", eng_index)
def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        local_engine_manager: Optional[CoreEngineProcManager] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Block until all watched processes exit, failing fast on any error.

    The API server processes, the coordinator process (if given) and the
    local engine processes (if given) are watched through their sentinels.

    A ``KeyboardInterrupt`` is logged and swallowed. Regardless of how the
    wait ends, the API server manager, the coordinator and the local engine
    manager are closed before returning.

    Args:
        api_server_manager: Manager owning the API server processes.
        local_engine_manager: Optional manager of local engine processes.
        coordinator: Optional DP coordinator; its ``proc`` is watched.

    Raises:
        RuntimeError: If any watched process exits with a non-zero code.
    """
    try:
        logger.info("Waiting for API servers to complete ...")

        # Map each sentinel to its process so a terminated process can be
        # identified directly from the sentinel that became ready.
        watched: dict[Any, BaseProcess] = {}
        for api_proc in api_server_manager.processes:
            watched[api_proc.sentinel] = api_proc

        if coordinator:
            coord_proc = coordinator.proc
            watched[coord_proc.sentinel] = coord_proc

        if local_engine_manager:
            watched.update((engine_proc.sentinel, engine_proc)
                           for engine_proc in local_engine_manager.processes)

        # Keep waiting until every watched process has terminated.
        while watched:
            finished_sentinels: list[Any] = connection.wait(watched)

            for finished in finished_sentinels:
                exited = watched.pop(finished)
                if exited.exitcode == 0:
                    # Clean exit; keep waiting for the remaining processes.
                    continue
                raise RuntimeError(
                    f"Process {exited.name} (PID: {exited.pid}) "
                    f"died with exit code {exited.exitcode}")
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as err:
        logger.exception("Exception occurred while running API servers: %s",
                         str(err))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if local_engine_manager:
            local_engine_manager.close()
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def
shutdown
(
procs
:
list
[
Process
]
,
input_address
:
str
):
# else the gc cannot collect the object.
def
shutdown
(
procs
:
list
[
Base
Process
]):
# Shutdown the process.
for
proc
in
procs
:
if
proc
.
is_alive
():
...
...
@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
kill_process_tree
(
pid
)
# Remove zmq ipc socket files.
if
input_address
.
startswith
(
"ipc://"
):
socket_file
=
input_address
[
len
(
"ipc://"
):]
if
os
and
os
.
path
.
exists
(
socket_file
):
os
.
remove
(
socket_file
)
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment