Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dbe8c07
Unverified
Commit
2dbe8c07
authored
May 30, 2025
by
Nick Hill
Committed by
GitHub
May 30, 2025
Browse files
[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
parent
84ec470f
Changes
26
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
818 additions
and
353 deletions
+818
-353
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+166
-87
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+244
-212
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+45
-34
vllm/v1/metrics/prometheus.py
vllm/v1/metrics/prometheus.py
+77
-0
vllm/v1/request.py
vllm/v1/request.py
+3
-2
vllm/v1/utils.py
vllm/v1/utils.py
+283
-18
No files found.
vllm/v1/engine/core.py
View file @
2dbe8c07
This diff is collapsed.
Click to expand it.
vllm/v1/engine/core_client.py
View file @
2dbe8c07
This diff is collapsed.
Click to expand it.
vllm/v1/metrics/loggers.py
View file @
2dbe8c07
...
@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
...
@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingLogging
,
SpecDecodingProm
from
vllm.v1.spec_decode.metrics
import
SpecDecodingLogging
,
SpecDecodingProm
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5.0
StatLoggerFactory
=
Callable
[[
VllmConfig
,
int
],
"StatLoggerBase"
]
StatLoggerFactory
=
Callable
[[
VllmConfig
,
int
],
"StatLoggerBase"
]
...
@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
...
@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
...
...
@
abstractmethod
@
abstractmethod
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
iteration_stats
:
Optional
[
IterationStats
]):
...
...
...
@@ -78,14 +77,16 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -78,14 +77,16 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats
# Compute summary metrics for tracked stats
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
iteration_stats
:
Optional
[
IterationStats
]):
"""Log Stats to standard output."""
"""Log Stats to standard output."""
if
iteration_stats
:
if
iteration_stats
:
self
.
_track_iteration_stats
(
iteration_stats
)
self
.
_track_iteration_stats
(
iteration_stats
)
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
is
not
None
:
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
self
.
spec_decoding_logging
.
observe
(
...
@@ -131,9 +132,10 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -131,9 +132,10 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
def
log_engine_initialized
(
self
):
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
logger
.
info
(
logger
.
info
(
"
vllm cache_config_info with initialization "
\
"Engine %03d:
vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d"
,
"after num_gpu_blocks is: %d"
,
self
.
engine_index
,
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
...
@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls
=
SpecDecodingProm
_spec_decoding_cls
=
SpecDecodingProm
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
self
.
_unregister_vllm_metrics
()
unregister_vllm_metrics
()
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
engine_index
=
engine_index
self
.
engine_index
=
engine_index
# Use this flag to hide metrics that were deprecated in
# Use this flag to hide metrics that were deprecated in
...
@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_running"
,
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests in model execution batches."
,
documentation
=
"Number of requests in model execution batches."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
documentation
=
"Number of requests waiting to be processed."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
#
#
...
@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
multiprocess_mode
=
"mostrecent"
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_gpu_prefix_cache_queries
=
self
.
_counter_cls
(
self
.
counter_gpu_prefix_cache_queries
=
self
.
_counter_cls
(
...
@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self
.
histogram_iteration_tokens
=
\
self
.
histogram_iteration_tokens
=
\
self
.
_histogram_cls
(
self
.
_histogram_cls
(
name
=
"vllm:iteration_tokens_total"
,
name
=
"vllm:iteration_tokens_total"
,
...
@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
#
#
# LoRA metrics
# LoRA metrics
#
#
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
self
.
gauge_lora_info
:
Optional
[
prometheus_client
.
Gauge
]
=
None
self
.
gauge_lora_info
:
Optional
[
prometheus_client
.
Gauge
]
=
None
if
vllm_config
.
lora_config
is
not
None
:
if
vllm_config
.
lora_config
is
not
None
:
self
.
labelname_max_lora
=
"max_lora"
self
.
labelname_max_lora
=
"max_lora"
...
@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
_gauge_cls
(
self
.
_gauge_cls
(
name
=
"vllm:lora_requests_info"
,
name
=
"vllm:lora_requests_info"
,
documentation
=
"Running stats on lora requests."
,
documentation
=
"Running stats on lora requests."
,
multiprocess_mode
=
"sum"
,
labelnames
=
[
labelnames
=
[
self
.
labelname_max_lora
,
self
.
labelname_max_lora
,
self
.
labelname_waiting_lora_adapters
,
self
.
labelname_waiting_lora_adapters
,
self
.
labelname_running_lora_adapters
,
self
.
labelname_running_lora_adapters
,
])
],
)
def
log_metrics_info
(
self
,
type
:
str
,
config_obj
:
SupportsMetricsInfo
):
def
log_metrics_info
(
self
,
type
:
str
,
config_obj
:
SupportsMetricsInfo
):
metrics_info
=
config_obj
.
metrics_info
()
metrics_info
=
config_obj
.
metrics_info
()
metrics_info
[
"engine"
]
=
self
.
engine_index
metrics_info
[
"engine"
]
=
self
.
engine_index
...
@@ -372,12 +387,15 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -372,12 +387,15 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge
=
self
.
_gauge_cls
(
info_gauge
=
self
.
_gauge_cls
(
name
=
name
,
name
=
name
,
documentation
=
documentation
,
documentation
=
documentation
,
labelnames
=
metrics_info
.
keys
()).
labels
(
**
metrics_info
)
multiprocess_mode
=
"mostrecent"
,
labelnames
=
metrics_info
.
keys
(),
).
labels
(
**
metrics_info
)
info_gauge
.
set
(
1
)
info_gauge
.
set
(
1
)
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
]):
iteration_stats
:
Optional
[
IterationStats
]):
"""Log to prometheus."""
"""Log to prometheus."""
if
scheduler_stats
is
not
None
:
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
...
@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
gauge_lora_info
.
labels
(
**
lora_info_labels
)
\
self
.
gauge_lora_info
.
labels
(
**
lora_info_labels
)
\
.
set_to_current_time
()
.
set_to_current_time
()
@
staticmethod
def
_unregister_vllm_metrics
():
# Unregister any existing vLLM collectors (for CI/CD
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
def
log_engine_initialized
(
self
):
def
log_engine_initialized
(
self
):
self
.
log_metrics_info
(
"cache_config"
,
self
.
vllm_config
.
cache_config
)
self
.
log_metrics_info
(
"cache_config"
,
self
.
vllm_config
.
cache_config
)
...
...
vllm/v1/metrics/prometheus.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
os
import
tempfile
from
typing
import
Optional
from
prometheus_client
import
REGISTRY
,
CollectorRegistry
,
multiprocess
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir
:
Optional
[
tempfile
.
TemporaryDirectory
]
=
None
def
setup_multiprocess_prometheus
():
"""Set up prometheus multiprocessing directory if not already configured.
"""
global
_prometheus_multiproc_dir
if
"PROMETHEUS_MULTIPROC_DIR"
not
in
os
.
environ
:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
_prometheus_multiproc_dir
=
tempfile
.
TemporaryDirectory
()
os
.
environ
[
"PROMETHEUS_MULTIPROC_DIR"
]
=
_prometheus_multiproc_dir
.
name
logger
.
debug
(
"Created PROMETHEUS_MULTIPROC_DIR at %s"
,
_prometheus_multiproc_dir
.
name
)
else
:
logger
.
warning
(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup."
)
def
get_prometheus_registry
():
"""Get the appropriate prometheus registry based on multiprocessing
configuration.
Returns:
Registry: A prometheus registry
"""
if
os
.
getenv
(
"PROMETHEUS_MULTIPROC_DIR"
)
is
not
None
:
logger
.
debug
(
"Using multiprocess registry for prometheus metrics"
)
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
return
registry
return
REGISTRY
def
unregister_vllm_metrics
():
"""Unregister any existing vLLM collectors from the prometheus registry.
This is useful for testing and CI/CD where metrics may be registered
multiple times across test runs.
Also, in case of multiprocess, we need to unregister the metrics from the
global registry.
"""
registry
=
REGISTRY
# Unregister any existing vLLM collectors
for
collector
in
list
(
registry
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
registry
.
unregister
(
collector
)
def
shutdown_prometheus
():
"""Shutdown prometheus metrics."""
try
:
pid
=
os
.
getpid
()
multiprocess
.
mark_process_dead
(
pid
)
logger
.
debug
(
"Marked Prometheus metrics for process %d as dead"
,
pid
)
except
Exception
as
e
:
logger
.
error
(
"Error during metrics cleanup: %s"
,
str
(
e
))
vllm/v1/request.py
View file @
2dbe8c07
...
@@ -26,12 +26,13 @@ class Request:
...
@@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
client_index
:
int
=
0
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
client_index
=
client_index
self
.
sampling_params
=
sampling_params
self
.
sampling_params
=
sampling_params
# Because of LoRA, the eos token id can be different for each request.
# Because of LoRA, the eos token id can be different for each request.
self
.
eos_token_id
=
eos_token_id
self
.
eos_token_id
=
eos_token_id
...
@@ -90,13 +91,13 @@ class Request:
...
@@ -90,13 +91,13 @@ class Request:
return
cls
(
return
cls
(
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
client_index
=
request
.
client_index
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt_token_ids
=
request
.
prompt_token_ids
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
sampling_params
=
request
.
sampling_params
,
sampling_params
=
request
.
sampling_params
,
eos_token_id
=
request
.
eos_token_id
,
eos_token_id
=
request
.
eos_token_id
,
arrival_time
=
request
.
arrival_time
,
lora_request
=
request
.
lora_request
,
lora_request
=
request
.
lora_request
,
structured_output_request
=
StructuredOutputRequest
(
structured_output_request
=
StructuredOutputRequest
(
sampling_params
=
request
.
sampling_params
),
sampling_params
=
request
.
sampling_params
),
...
...
vllm/v1/utils.py
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
argparse
import
multiprocessing
import
time
import
time
import
weakref
import
weakref
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
multiprocessing
import
Process
,
connection
from
multiprocessing
import
Process
,
connection
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
from
multiprocessing.process
import
BaseProcess
overload
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
import
msgspec
import
torch
import
torch
import
zmq
from
vllm.config
import
VllmConfig
from
vllm.config
import
CacheConfig
,
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
get_mp_context
,
kill_process_tree
from
vllm.utils
import
(
get_mp_context
,
get_open_port
,
get_open_zmq_ipc_path
,
get_tcp_uri
,
kill_process_tree
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.v1.engine.coordinator
import
DPCoordinator
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
STARTUP_POLL_PERIOD_MS
=
10000
class
ConstantList
(
Generic
[
T
],
Sequence
):
class
ConstantList
(
Generic
[
T
],
Sequence
):
...
@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
...
@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return
f
"ConstantList(
{
self
.
_x
}
)"
return
f
"ConstantList(
{
self
.
_x
}
)"
def
get_engine_client_zmq_addr
(
local_only
:
bool
,
host
:
str
,
port
:
int
=
0
)
->
str
:
return
get_open_zmq_ipc_path
()
if
local_only
else
(
get_tcp_uri
(
host
,
port
or
get_open_port
()))
class
APIServerProcessManager
:
"""Manages a group of API server processes.
Handles creation, monitoring, and termination of API server worker
processes. Also monitors extra processes to check if they are healthy.
"""
def
__init__
(
self
,
target_server_fn
:
Callable
,
listen_address
:
str
,
sock
:
Any
,
args
:
argparse
.
Namespace
,
num_servers
:
int
,
input_addresses
:
list
[
str
],
output_addresses
:
list
[
str
],
stats_update_address
:
Optional
[
str
]
=
None
,
):
"""Initialize and start API server worker processes.
Args:
target_server_fn: Function to call for each API server process
listen_address: Address to listen for client connections
sock: Socket for client connections
args: Command line arguments
num_servers: Number of API server processes to start
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
"""
self
.
listen_address
=
listen_address
self
.
sock
=
sock
self
.
args
=
args
# Start API servers
spawn_context
=
multiprocessing
.
get_context
(
"spawn"
)
self
.
processes
:
list
[
BaseProcess
]
=
[]
for
i
,
in_addr
,
out_addr
in
zip
(
range
(
num_servers
),
input_addresses
,
output_addresses
):
client_config
=
{
"input_address"
:
in_addr
,
"output_address"
:
out_addr
,
"client_index"
:
i
}
if
stats_update_address
is
not
None
:
client_config
[
"stats_update_address"
]
=
stats_update_address
proc
=
spawn_context
.
Process
(
target
=
target_server_fn
,
name
=
f
"ApiServer_
{
i
}
"
,
args
=
(
listen_address
,
sock
,
args
,
client_config
))
self
.
processes
.
append
(
proc
)
proc
.
start
()
logger
.
info
(
"Started %d API server processes"
,
len
(
self
.
processes
))
# Shutdown only the API server processes on garbage collection
# The extra processes are managed by their owners
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
def
close
(
self
)
->
None
:
self
.
_finalizer
()
class
CoreEngineProcManager
:
class
CoreEngineProcManager
:
"""
"""
Utility class to handle creation, readiness, and shutdown
Utility class to handle creation, readiness, and shutdown
...
@@ -109,7 +191,7 @@ class CoreEngineProcManager:
...
@@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index
:
int
,
local_start_index
:
int
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
on_head_node
:
bool
,
on_head_node
:
bool
,
input
_address
:
str
,
handshake
_address
:
str
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
):
):
...
@@ -117,12 +199,12 @@ class CoreEngineProcManager:
...
@@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs
=
{
common_kwargs
=
{
"vllm_config"
:
vllm_config
,
"vllm_config"
:
vllm_config
,
"on_head_node"
:
on_head_node
,
"on_head_node"
:
on_head_node
,
"
input
_address"
:
input
_address
,
"
handshake
_address"
:
handshake
_address
,
"executor_class"
:
executor_class
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
"log_stats"
:
log_stats
,
}
}
self
.
processes
:
list
[
Process
]
=
[]
self
.
processes
:
list
[
Base
Process
]
=
[]
for
index
in
range
(
local_engine_count
):
for
index
in
range
(
local_engine_count
):
local_index
=
local_start_index
+
index
local_index
=
local_start_index
+
index
global_index
=
start_index
+
index
global_index
=
start_index
+
index
...
@@ -135,8 +217,7 @@ class CoreEngineProcManager:
...
@@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank"
:
local_index
,
"local_dp_rank"
:
local_index
,
}))
}))
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
,
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
processes
)
input_address
)
try
:
try
:
for
proc
in
self
.
processes
:
for
proc
in
self
.
processes
:
proc
.
start
()
proc
.
start
()
...
@@ -164,9 +245,199 @@ class CoreEngineProcManager:
...
@@ -164,9 +245,199 @@ class CoreEngineProcManager:
}
}
class
CoreEngineState
(
Enum
):
NEW
=
auto
()
CONNECTED
=
auto
()
READY
=
auto
()
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
index
:
int
=
0
,
local
:
bool
=
True
):
self
.
local
=
local
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
2
,
"little"
)
self
.
state
=
CoreEngineState
.
NEW
@
dataclass
class
EngineZmqAddresses
:
# ZMQ input socket addresses for each front-end client (requests)
inputs
:
list
[
str
]
# ZMQ output socket addresses for each front-end client (responses)
outputs
:
list
[
str
]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input
:
Optional
[
str
]
=
None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output
:
Optional
[
str
]
=
None
@
dataclass
class
EngineHandshakeMetadata
:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses
:
EngineZmqAddresses
parallel_config
:
dict
[
str
,
Union
[
int
,
str
]]
def
wait_for_engine_startup
(
handshake_socket
:
zmq
.
Socket
,
addresses
:
EngineZmqAddresses
,
core_engines
:
list
[
CoreEngine
],
parallel_config
:
ParallelConfig
,
cache_config
:
CacheConfig
,
proc_manager
:
Optional
[
CoreEngineProcManager
],
coord_process
:
Optional
[
Process
],
):
# Wait for engine core process(es) to send ready messages.
local_count
=
parallel_config
.
data_parallel_size_local
remote_count
=
len
(
core_engines
)
-
local_count
# [local, remote] counts
conn_pending
,
start_pending
=
[
local_count
,
remote_count
],
[
0
,
0
]
poller
=
zmq
.
Poller
()
poller
.
register
(
handshake_socket
,
zmq
.
POLLIN
)
if
proc_manager
is
not
None
:
for
sentinel
in
proc_manager
.
sentinels
():
poller
.
register
(
sentinel
,
zmq
.
POLLIN
)
if
coord_process
is
not
None
:
poller
.
register
(
coord_process
.
sentinel
,
zmq
.
POLLIN
)
while
any
(
conn_pending
)
or
any
(
start_pending
):
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
if
any
(
conn_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect."
,
*
conn_pending
)
if
any
(
start_pending
):
logger
.
debug
(
"Waiting for %d local, %d remote core engine proc(s) "
"to start."
,
*
start_pending
)
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
handshake_socket
:
# One of the local core processes exited.
finished
=
proc_manager
.
finished_procs
()
if
proc_manager
else
{}
if
coord_process
is
not
None
and
coord_process
.
exitcode
is
not
None
:
finished
[
coord_process
.
name
]
=
coord_process
.
exitcode
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above. "
f
"Failed core proc(s):
{
finished
}
"
)
# Receive HELLO and READY messages from the input socket.
eng_identity
,
ready_msg_bytes
=
handshake_socket
.
recv_multipart
()
eng_index
=
int
.
from_bytes
(
eng_identity
,
"little"
)
engine
=
next
((
e
for
e
in
core_engines
if
e
.
identity
==
eng_identity
),
None
)
if
engine
is
None
:
raise
RuntimeError
(
f
"Message from engine with unexpected data "
f
"parallel rank:
{
eng_index
}
"
)
msg
=
msgspec
.
msgpack
.
decode
(
ready_msg_bytes
)
status
,
local
=
msg
[
"status"
],
msg
[
"local"
]
if
local
!=
engine
.
local
:
raise
RuntimeError
(
f
"
{
status
}
message from "
f
"
{
'local'
if
local
else
'remote'
}
"
f
"engine
{
eng_index
}
, expected it to be "
f
"
{
'local'
if
engine
.
local
else
'remote'
}
"
)
if
status
==
"HELLO"
and
engine
.
state
==
CoreEngineState
.
NEW
:
# Send init message with DP config info.
init_message
=
msgspec
.
msgpack
.
encode
(
EngineHandshakeMetadata
(
addresses
=
addresses
,
parallel_config
=
{
"data_parallel_master_ip"
:
parallel_config
.
data_parallel_master_ip
,
"data_parallel_master_port"
:
parallel_config
.
data_parallel_master_port
,
"data_parallel_size"
:
parallel_config
.
data_parallel_size
,
}))
handshake_socket
.
send_multipart
((
eng_identity
,
init_message
),
copy
=
False
)
conn_pending
[
0
if
local
else
1
]
-=
1
start_pending
[
0
if
local
else
1
]
+=
1
engine
.
state
=
CoreEngineState
.
CONNECTED
elif
status
==
"READY"
and
(
engine
.
state
==
CoreEngineState
.
CONNECTED
):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
or
0
num_gpu_blocks
+=
msg
[
"num_gpu_blocks"
]
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
start_pending
[
0
if
local
else
1
]
-=
1
engine
.
state
=
CoreEngineState
.
READY
else
:
raise
RuntimeError
(
f
"Unexpected
{
status
}
message for "
f
"
{
'local'
if
local
else
'remote'
}
engine "
f
"
{
eng_index
}
in
{
engine
.
state
}
state."
)
logger
.
debug
(
"%s from %s core engine process %s."
,
status
,
"local"
if
local
else
"remote"
,
eng_index
)
def
wait_for_completion_or_failure
(
api_server_manager
:
APIServerProcessManager
,
local_engine_manager
:
Optional
[
CoreEngineProcManager
]
=
None
,
coordinator
:
Optional
[
"DPCoordinator"
]
=
None
)
->
None
:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
"""
try
:
logger
.
info
(
"Waiting for API servers to complete ..."
)
# Create a mapping of sentinels to their corresponding processes
# for efficient lookup
sentinel_to_proc
:
dict
[
Any
,
BaseProcess
]
=
{
proc
.
sentinel
:
proc
for
proc
in
api_server_manager
.
processes
}
if
coordinator
:
sentinel_to_proc
[
coordinator
.
proc
.
sentinel
]
=
coordinator
.
proc
if
local_engine_manager
:
for
proc
in
local_engine_manager
.
processes
:
sentinel_to_proc
[
proc
.
sentinel
]
=
proc
# Check if any process terminates
while
sentinel_to_proc
:
# Wait for any process to terminate
ready_sentinels
:
list
[
Any
]
=
connection
.
wait
(
sentinel_to_proc
)
# Process any terminated processes
for
sentinel
in
ready_sentinels
:
proc
=
sentinel_to_proc
.
pop
(
sentinel
)
# Check if process exited with error
if
proc
.
exitcode
!=
0
:
raise
RuntimeError
(
f
"Process
{
proc
.
name
}
(PID:
{
proc
.
pid
}
) "
f
"died with exit code
{
proc
.
exitcode
}
"
)
except
KeyboardInterrupt
:
logger
.
info
(
"Received KeyboardInterrupt, shutting down API servers..."
)
except
Exception
as
e
:
logger
.
exception
(
"Exception occurred while running API servers: %s"
,
str
(
e
))
raise
finally
:
logger
.
info
(
"Terminating remaining processes ..."
)
api_server_manager
.
close
()
if
coordinator
:
coordinator
.
close
()
if
local_engine_manager
:
local_engine_manager
.
close
()
# Note(rob): shutdown function cannot be a bound method,
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the obje
decoup
ct.
# else the gc cannot collect the object.
def
shutdown
(
procs
:
list
[
Process
]
,
input_address
:
str
):
def
shutdown
(
procs
:
list
[
Base
Process
]):
# Shutdown the process.
# Shutdown the process.
for
proc
in
procs
:
for
proc
in
procs
:
if
proc
.
is_alive
():
if
proc
.
is_alive
():
...
@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
...
@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
if
proc
.
is_alive
()
and
(
pid
:
=
proc
.
pid
)
is
not
None
:
kill_process_tree
(
pid
)
kill_process_tree
(
pid
)
# Remove zmq ipc socket files.
if
input_address
.
startswith
(
"ipc://"
):
socket_file
=
input_address
[
len
(
"ipc://"
):]
if
os
and
os
.
path
.
exists
(
socket_file
):
os
.
remove
(
socket_file
)
def
bind_kv_cache
(
def
bind_kv_cache
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment