Unverified commit b5e60825, authored by fangyuchu and committed by GitHub
Browse files

[Refactor] Unify engine process monitoring in engine manager and add Ray backend support (#35862)


Signed-off-by: fangyuchu <fangyuchu@qq.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 2c734ed0
......@@ -225,7 +225,7 @@ def run_headless(args: argparse.Namespace):
)
try:
engine_manager.join_first()
engine_manager.monitor_engine_liveness()
finally:
timeout = None
if shutdown_requested:
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import multiprocessing
import queue
import sys
import uuid
......@@ -640,34 +639,20 @@ class MPClient(EngineCoreClient):
def start_engine_core_monitor(self):
"""Start a monitor thread for engine core processes."""
engine_manager = self.resources.engine_manager
if (
engine_manager is None
or not hasattr(engine_manager, "processes")
or not engine_manager.processes
):
if engine_manager is None:
# No engine processes to monitor
return
engine_processes = engine_manager.processes
self_ref = weakref.ref(self)
# Monitor engine core process liveness. If any die unexpectedly,
# logs an error, shuts down the client and invokes the failure
# callback to inform the engine.
# marks the engine as dead, and shuts down the client.
def monitor_engine_cores():
sentinels = [proc.sentinel for proc in engine_processes]
died = multiprocessing.connection.wait(sentinels)
engine_manager.monitor_engine_liveness()
_self = self_ref()
if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
return
_self.resources.engine_dead = True
proc_name = next(
proc.name for proc in engine_processes if proc.sentinel == died[0]
)
logger.error(
"Engine core proc %s died unexpectedly, shutting down client.",
proc_name,
)
_self.shutdown()
# Note: For MPClient, we don't have a failure callback mechanism
# like MultiprocExecutor, but we set engine_dead flag which will
......@@ -1634,6 +1619,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
parallel_config = self.vllm_config.parallel_config
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
removed_dp_size = cur_data_parallel_size - new_data_parallel_size
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.remove_run_refs_for_scale_down(removed_dp_size)
reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines):
reconfig_request = ReconfigureDistributedRequest(
......
......@@ -11,7 +11,7 @@ from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from unittest.mock import patch
import msgspec
......@@ -133,6 +133,8 @@ class CoreEngineProcManager:
)
self._finalizer = weakref.finalize(self, shutdown, self.processes)
self.manager_stopped = threading.Event()
self.failed_proc_name: str | None = None
try:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
......@@ -154,12 +156,31 @@ class CoreEngineProcManager:
def shutdown(self, timeout: float | None = None) -> None:
    """Stop the manager and shut down its engine core processes.

    Args:
        timeout: Passed through unchanged to the module-level
            ``shutdown`` helper.
    """
    # Raise the stop flag before anything else; the liveness monitor
    # checks this event to decide whether to keep polling.
    self.manager_stopped.set()

    # ``detach()`` returns None when the finalizer is already dead, in
    # which case there is nothing left to do here.
    detached = self._finalizer.detach()
    if detached is None:
        return
    shutdown(self.processes, timeout=timeout)
def join_first(self):
    """Block until at least one of the managed processes has exited."""
    exit_sentinels = [engine_proc.sentinel for engine_proc in self.processes]
    connection.wait(exit_sentinels)
def monitor_engine_liveness(self) -> None:
    """Watch the engine core processes and shut down once one exits.

    The process sentinels are polled in one-second slices so that the
    ``manager_stopped`` event is re-checked between waits. When a process
    exits with a non-zero exit code while the manager has not been
    stopped, its name is recorded in ``failed_proc_name``. ``shutdown()``
    is called before returning in every case.
    """
    procs_by_sentinel = {}
    for engine_proc in self.processes:
        procs_by_sentinel[engine_proc.sentinel] = engine_proc
    watched = set(procs_by_sentinel)

    while watched and not self.manager_stopped.is_set():
        exited = connection.wait(watched, timeout=1)
        if not exited:
            # Timed out with every process still running; poll again.
            continue
        for sentinel in exited:
            dead_proc = procs_by_sentinel.pop(cast(int, sentinel))
            failed = dead_proc.exitcode != 0
            if failed and not self.manager_stopped.is_set():
                self.failed_proc_name = dead_proc.name
        # Any engine exit currently triggers a shutdown. Future
        # work (e.g., Elastic and fault-tolerant EP) will add finer-grained
        # handling for different exit scenarios.
        break

    self.shutdown()
def sentinels(self) -> list:
    """Return the sentinel of every managed process, in process order."""
    collected = []
    for engine_proc in self.processes:
        collected.append(engine_proc.sentinel)
    return collected
......@@ -298,6 +319,8 @@ class CoreEngineActorManager:
self.log_stats = log_stats
local_engine_count = vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size
self.manager_stopped = threading.Event()
self.failed_proc_name: str | None = None
if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.")
......@@ -395,8 +418,11 @@ class CoreEngineActorManager:
ray.get(refs)
self.run_refs = []
self.actor_run_ref_dict = dict()
for actor in self.local_engine_actors + self.remote_engine_actors:
self.run_refs.append(actor.run.remote())
ref = actor.run.remote()
self.run_refs.append(ref)
self.actor_run_ref_dict[actor] = ref
@staticmethod
def create_dp_placement_groups(
......@@ -776,7 +802,9 @@ class CoreEngineActorManager:
) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :]
for actor in actors:
self.run_refs.append(actor.run.remote())
ref = actor.run.remote()
self.run_refs.append(ref)
self.actor_run_ref_dict[actor] = ref
cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# Update old_vllm_config with new data_parallel_size_local if any new
......@@ -805,12 +833,59 @@ class CoreEngineActorManager:
self.remote_engine_actors.pop()
ray.util.remove_placement_group(pg)
def remove_run_refs_for_scale_down(self, removed_dp_size: int) -> None:
    """Stop tracking the run refs of the actors selected for scale-down.

    The last ``removed_dp_size`` entries of ``placement_group_is_local``
    are walked from the end. Each entry selects the next actor from the
    tail of ``local_engine_actors`` (truthy flag) or
    ``remote_engine_actors`` (falsy flag). That actor's run ref is popped
    from ``actor_run_ref_dict`` and removed from ``run_refs``. The actor
    lists themselves are not modified here.

    Args:
        removed_dp_size: Number of trailing placement groups whose run
            refs should be dropped; values <= 0 are a no-op.
    """
    if removed_dp_size <= 0:
        return

    # Cursors pointing at the last not-yet-selected actor of each kind.
    local_cursor = len(self.local_engine_actors) - 1
    remote_cursor = len(self.remote_engine_actors) - 1

    trailing_flags = self.placement_group_is_local[-removed_dp_size:]
    for group_is_local in reversed(trailing_flags):
        if not group_is_local:
            actor = self.remote_engine_actors[remote_cursor]
            remote_cursor -= 1
        else:
            actor = self.local_engine_actors[local_cursor]
            local_cursor -= 1
        run_ref = self.actor_run_ref_dict.pop(actor)
        self.run_refs.remove(run_ref)
def get_run_refs(self):
    """Return the manager's ``run_refs`` list itself (not a copy)."""
    tracked_refs = self.run_refs
    return tracked_refs
def monitor_engine_liveness(self) -> None:
    """Watch the engine actors' ``run()`` refs and shut down on failure.

    Loops until ``manager_stopped`` is set, no run refs are left to watch,
    or a still-tracked run ref raises ``ray.exceptions.RayActorError`` when
    fetched. In the failure case ``failed_proc_name`` is set to a string
    naming the ref. ``shutdown()`` is called once the loop ends.

    Run refs that complete without a ``RayActorError`` are remembered and
    left out of later waits. Without this, ``ray.wait`` would hand the
    same finished ref back immediately on every iteration and the loop
    would spin without ever blocking.
    """
    import ray

    # Refs whose run() call already completed without an actor error.
    # Kept as a list: membership is checked with ``in``, like run_refs.
    finished_refs: list = []

    while not self.manager_stopped.is_set():
        # Snapshot first: run_refs may be changed by elastic scaling
        # while this loop is running.
        actor_run_refs = [
            ref for ref in list(self.get_run_refs()) if ref not in finished_refs
        ]
        if not actor_run_refs:
            logger.info(
                "There are no actors to monitor currently. "
                "The monitoring function is about to terminate."
            )
            break

        # The timeout bounds the wait so the stop flag and the current
        # run_refs are re-read regularly.
        actor_done_refs, _ = ray.wait(actor_run_refs, timeout=5)
        unexpected_failure = False
        for actor_ref in actor_done_refs:
            if self.manager_stopped.is_set():
                break
            if actor_ref not in self.get_run_refs():
                # The run refs may have been updated by elastic scale-down.
                continue
            try:
                ray.get(actor_ref)
            except ray.exceptions.RayActorError:
                self.failed_proc_name = f"Actor {actor_ref}"
                unexpected_failure = True
            else:
                # Completed cleanly: stop waiting on it so later
                # iterations block on the refs that are still pending.
                finished_refs.append(actor_ref)
        if unexpected_failure:
            break

    self.shutdown()
def shutdown(self, timeout: float | None = None) -> None:
import ray
self.manager_stopped.set()
for actor in self.local_engine_actors + self.remote_engine_actors:
ray.kill(actor)
for pg in self.created_placement_groups:
......
......@@ -3,6 +3,7 @@
import argparse
import contextlib
import multiprocessing
import threading
import time
import weakref
from collections.abc import Callable, Sequence
......@@ -269,8 +270,6 @@ def wait_for_completion_or_failure(
coordinator: The coordinator for data parallel.
"""
from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
try:
logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes
......@@ -282,33 +281,40 @@ def wait_for_completion_or_failure(
if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
actor_run_refs = []
if isinstance(engine_manager, CoreEngineProcManager):
for proc in engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc
elif isinstance(engine_manager, CoreEngineActorManager):
actor_run_refs = engine_manager.get_run_refs()
if engine_manager:
core_shutdown_recv, core_shutdown_send = connection.Pipe(duplex=False)
def monitor_engines():
try:
engine_manager.monitor_engine_liveness()
finally:
core_shutdown_send.close()
core_shutdown_recv.close()
# start monitor for engine liveness
threading.Thread(target=monitor_engines, daemon=True).start()
sentinel_to_proc[core_shutdown_recv] = None # type: ignore[assignment]
# Check if any process terminates
while sentinel_to_proc or actor_run_refs:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5)
while sentinel_to_proc:
# Wait for any process to terminate (or engine shutdown signal)
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
# Process any terminated processes
for sentinel in ready_sentinels:
proc = sentinel_to_proc.pop(sentinel)
# Check if process exited with error
if proc.exitcode != 0:
if proc is not None and proc.exitcode != 0:
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}"
)
if actor_run_refs:
import ray
_, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
if engine_manager and engine_manager.failed_proc_name is not None:
raise RuntimeError(
f"Engine core process {engine_manager.failed_proc_name} "
"died unexpectedly."
)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment