Unverified Commit 2dbe8c07 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

parent 84ec470f
This diff is collapsed.
This diff is collapsed.
...@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig ...@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5.0
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
...@@ -35,7 +34,7 @@ class StatLoggerBase(ABC): ...@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
... ...
@abstractmethod @abstractmethod
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]): iteration_stats: Optional[IterationStats]):
... ...
...@@ -78,14 +77,16 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -78,14 +77,16 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats # Compute summary metrics for tracked stats
return float(np.sum(tracked_stats) / (now - self.last_log_time)) return float(np.sum(tracked_stats) / (now - self.last_log_time))
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]): iteration_stats: Optional[IterationStats]):
"""Log Stats to standard output.""" """Log Stats to standard output."""
if iteration_stats: if iteration_stats:
self._track_iteration_stats(iteration_stats) self._track_iteration_stats(iteration_stats)
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats is not None:
self.prefix_caching_metrics.observe(
scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None: if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe( self.spec_decoding_logging.observe(
...@@ -131,9 +132,10 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -131,9 +132,10 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging.log(log_fn=log_fn) self.spec_decoding_logging.log(log_fn=log_fn)
def log_engine_initialized(self): def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info( logger.info(
"vllm cache_config_info with initialization " \ "Engine %03d: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", "after num_gpu_blocks is: %d", self.engine_index,
self.vllm_config.cache_config.num_gpu_blocks) self.vllm_config.cache_config.num_gpu_blocks)
...@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls = SpecDecodingProm _spec_decoding_cls = SpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics()
unregister_vllm_metrics()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.engine_index = engine_index self.engine_index = engine_index
# Use this flag to hide metrics that were deprecated in # Use this flag to hide metrics that were deprecated in
...@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_scheduler_running = self._gauge_cls( self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running", name="vllm:num_requests_running",
documentation="Number of requests in model execution batches.", documentation="Number of requests in model execution batches.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.gauge_scheduler_waiting = self._gauge_cls( self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting", name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.", documentation="Number of requests waiting to be processed.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
# #
...@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_gpu_cache_usage = self._gauge_cls( self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc", name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_queries = self._counter_cls( self.counter_gpu_prefix_cache_queries = self._counter_cls(
...@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self.histogram_iteration_tokens = \ self.histogram_iteration_tokens = \
self._histogram_cls( self._histogram_cls(
name="vllm:iteration_tokens_total", name="vllm:iteration_tokens_total",
...@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
# #
# LoRA metrics # LoRA metrics
# #
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
if vllm_config.lora_config is not None: if vllm_config.lora_config is not None:
self.labelname_max_lora = "max_lora" self.labelname_max_lora = "max_lora"
...@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self._gauge_cls( self._gauge_cls(
name="vllm:lora_requests_info", name="vllm:lora_requests_info",
documentation="Running stats on lora requests.", documentation="Running stats on lora requests.",
multiprocess_mode="sum",
labelnames=[ labelnames=[
self.labelname_max_lora, self.labelname_max_lora,
self.labelname_waiting_lora_adapters, self.labelname_waiting_lora_adapters,
self.labelname_running_lora_adapters, self.labelname_running_lora_adapters,
]) ],
)
def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
metrics_info = config_obj.metrics_info() metrics_info = config_obj.metrics_info()
metrics_info["engine"] = self.engine_index metrics_info["engine"] = self.engine_index
...@@ -372,12 +387,15 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -372,12 +387,15 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge = self._gauge_cls( info_gauge = self._gauge_cls(
name=name, name=name,
documentation=documentation, documentation=documentation,
labelnames=metrics_info.keys()).labels(**metrics_info) multiprocess_mode="mostrecent",
labelnames=metrics_info.keys(),
).labels(**metrics_info)
info_gauge.set(1) info_gauge.set(1)
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]): iteration_stats: Optional[IterationStats]):
"""Log to prometheus.""" """Log to prometheus."""
if scheduler_stats is not None:
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
...@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_lora_info.labels(**lora_info_labels)\ self.gauge_lora_info.labels(**lora_info_labels)\
.set_to_current_time() .set_to_current_time()
@staticmethod
def _unregister_vllm_metrics():
# Unregister any existing vLLM collectors (for CI/CD
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
prometheus_client.REGISTRY.unregister(collector)
def log_engine_initialized(self): def log_engine_initialized(self):
self.log_metrics_info("cache_config", self.vllm_config.cache_config) self.log_metrics_info("cache_config", self.vllm_config.cache_config)
......
# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
from typing import Optional
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None
def setup_multiprocess_prometheus():
"""Set up prometheus multiprocessing directory if not already configured.
"""
global _prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
_prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
_prometheus_multiproc_dir.name)
else:
logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
def get_prometheus_registry():
"""Get the appropriate prometheus registry based on multiprocessing
configuration.
Returns:
Registry: A prometheus registry
"""
if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
logger.debug("Using multiprocess registry for prometheus metrics")
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
return registry
return REGISTRY
def unregister_vllm_metrics():
"""Unregister any existing vLLM collectors from the prometheus registry.
This is useful for testing and CI/CD where metrics may be registered
multiple times across test runs.
Also, in case of multiprocess, we need to unregister the metrics from the
global registry.
"""
registry = REGISTRY
# Unregister any existing vLLM collectors
for collector in list(registry._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
registry.unregister(collector)
def shutdown_prometheus():
"""Shutdown prometheus metrics."""
try:
pid = os.getpid()
multiprocess.mark_process_dead(pid)
logger.debug("Marked Prometheus metrics for process %d as dead", pid)
except Exception as e:
logger.error("Error during metrics cleanup: %s", str(e))
...@@ -26,12 +26,13 @@ class Request: ...@@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders: Optional[list[PlaceholderRange]], multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams, sampling_params: SamplingParams,
eos_token_id: Optional[int], eos_token_id: Optional[int],
arrival_time: float, client_index: int = 0,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.client_index = client_index
self.sampling_params = sampling_params self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request. # Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
...@@ -90,13 +91,13 @@ class Request: ...@@ -90,13 +91,13 @@ class Request:
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs, multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes, multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders, multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id, eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request, lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest( structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params), sampling_params=request.sampling_params),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import argparse
import multiprocessing
import time import time
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection from multiprocessing import Process, connection
from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, from multiprocessing.process import BaseProcess
overload) from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import msgspec
import torch import torch
import zmq
from vllm.config import VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import get_mp_context, kill_process_tree from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
get_tcp_uri, kill_process_tree)
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar("T") T = TypeVar("T")
STARTUP_POLL_PERIOD_MS = 10000
class ConstantList(Generic[T], Sequence): class ConstantList(Generic[T], Sequence):
...@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence): ...@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return f"ConstantList({self._x})" return f"ConstantList({self._x})"
def get_engine_client_zmq_addr(local_only: bool,
host: str,
port: int = 0) -> str:
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
host, port or get_open_port()))
class APIServerProcessManager:
"""Manages a group of API server processes.
Handles creation, monitoring, and termination of API server worker
processes. Also monitors extra processes to check if they are healthy.
"""
def __init__(
self,
target_server_fn: Callable,
listen_address: str,
sock: Any,
args: argparse.Namespace,
num_servers: int,
input_addresses: list[str],
output_addresses: list[str],
stats_update_address: Optional[str] = None,
):
"""Initialize and start API server worker processes.
Args:
target_server_fn: Function to call for each API server process
listen_address: Address to listen for client connections
sock: Socket for client connections
args: Command line arguments
num_servers: Number of API server processes to start
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
"""
self.listen_address = listen_address
self.sock = sock
self.args = args
# Start API servers
spawn_context = multiprocessing.get_context("spawn")
self.processes: list[BaseProcess] = []
for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
output_addresses):
client_config = {
"input_address": in_addr,
"output_address": out_addr,
"client_index": i
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address
proc = spawn_context.Process(target=target_server_fn,
name=f"ApiServer_{i}",
args=(listen_address, sock, args,
client_config))
self.processes.append(proc)
proc.start()
logger.info("Started %d API server processes", len(self.processes))
# Shutdown only the API server processes on garbage collection
# The extra processes are managed by their owners
self._finalizer = weakref.finalize(self, shutdown, self.processes)
def close(self) -> None:
self._finalizer()
class CoreEngineProcManager: class CoreEngineProcManager:
""" """
Utility class to handle creation, readiness, and shutdown Utility class to handle creation, readiness, and shutdown
...@@ -109,7 +191,7 @@ class CoreEngineProcManager: ...@@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index: int, local_start_index: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
on_head_node: bool, on_head_node: bool,
input_address: str, handshake_address: str,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
): ):
...@@ -117,12 +199,12 @@ class CoreEngineProcManager: ...@@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs = { common_kwargs = {
"vllm_config": vllm_config, "vllm_config": vllm_config,
"on_head_node": on_head_node, "on_head_node": on_head_node,
"input_address": input_address, "handshake_address": handshake_address,
"executor_class": executor_class, "executor_class": executor_class,
"log_stats": log_stats, "log_stats": log_stats,
} }
self.processes: list[Process] = [] self.processes: list[BaseProcess] = []
for index in range(local_engine_count): for index in range(local_engine_count):
local_index = local_start_index + index local_index = local_start_index + index
global_index = start_index + index global_index = start_index + index
...@@ -135,8 +217,7 @@ class CoreEngineProcManager: ...@@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank": local_index, "local_dp_rank": local_index,
})) }))
self._finalizer = weakref.finalize(self, shutdown, self.processes, self._finalizer = weakref.finalize(self, shutdown, self.processes)
input_address)
try: try:
for proc in self.processes: for proc in self.processes:
proc.start() proc.start()
...@@ -164,9 +245,199 @@ class CoreEngineProcManager: ...@@ -164,9 +245,199 @@ class CoreEngineProcManager:
} }
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
def wait_for_engine_startup(
handshake_socket: zmq.Socket,
addresses: EngineZmqAddresses,
core_engines: list[CoreEngine],
parallel_config: ParallelConfig,
cache_config: CacheConfig,
proc_manager: Optional[CoreEngineProcManager],
coord_process: Optional[Process],
):
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, "little")
engine = next((e for e in core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
parallel_config={
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
}))
handshake_socket.send_multipart((eng_identity, init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)
def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
local_engine_manager: Optional[CoreEngineProcManager] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
"""
try:
logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes
# for efficient lookup
sentinel_to_proc: dict[Any, BaseProcess] = {
proc.sentinel: proc
for proc in api_server_manager.processes
}
if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
if local_engine_manager:
for proc in local_engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc
# Check if any process terminates
while sentinel_to_proc:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
# Process any terminated processes
for sentinel in ready_sentinels:
proc = sentinel_to_proc.pop(sentinel)
# Check if process exited with error
if proc.exitcode != 0:
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}")
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
except Exception as e:
logger.exception("Exception occurred while running API servers: %s",
str(e))
raise
finally:
logger.info("Terminating remaining processes ...")
api_server_manager.close()
if coordinator:
coordinator.close()
if local_engine_manager:
local_engine_manager.close()
# Note(rob): shutdown function cannot be a bound method, # Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the objedecoupct. # else the gc cannot collect the object.
def shutdown(procs: list[Process], input_address: str): def shutdown(procs: list[BaseProcess]):
# Shutdown the process. # Shutdown the process.
for proc in procs: for proc in procs:
if proc.is_alive(): if proc.is_alive():
...@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str): ...@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if proc.is_alive() and (pid := proc.pid) is not None: if proc.is_alive() and (pid := proc.pid) is not None:
kill_process_tree(pid) kill_process_tree(pid)
# Remove zmq ipc socket files.
if input_address.startswith("ipc://"):
socket_file = input_address[len("ipc://"):]
if os and os.path.exists(socket_file):
os.remove(socket_file)
def bind_kv_cache( def bind_kv_cache(
kv_caches: dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment