"vscode:/vscode.git/clone" did not exist on "f91480b2d44c263fb600b5cba5b0e6c7a195f742"
Unverified commit b5e60825, authored by fangyuchu, committed by GitHub
Browse files

[Refactor] Unify engine process monitoring in engine manager and add Ray backend support (#35862)


Signed-off-by: fangyuchu <fangyuchu@qq.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 2c734ed0
...@@ -225,7 +225,7 @@ def run_headless(args: argparse.Namespace): ...@@ -225,7 +225,7 @@ def run_headless(args: argparse.Namespace):
) )
try: try:
engine_manager.join_first() engine_manager.monitor_engine_liveness()
finally: finally:
timeout = None timeout = None
if shutdown_requested: if shutdown_requested:
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import contextlib import contextlib
import multiprocessing
import queue import queue
import sys import sys
import uuid import uuid
...@@ -640,34 +639,20 @@ class MPClient(EngineCoreClient): ...@@ -640,34 +639,20 @@ class MPClient(EngineCoreClient):
def start_engine_core_monitor(self): def start_engine_core_monitor(self):
"""Start a monitor thread for engine core processes.""" """Start a monitor thread for engine core processes."""
engine_manager = self.resources.engine_manager engine_manager = self.resources.engine_manager
if ( if engine_manager is None:
engine_manager is None
or not hasattr(engine_manager, "processes")
or not engine_manager.processes
):
# No engine processes to monitor # No engine processes to monitor
return return
engine_processes = engine_manager.processes
self_ref = weakref.ref(self) self_ref = weakref.ref(self)
# Monitor engine core process liveness. If any die unexpectedly, # Monitor engine core process liveness. If any die unexpectedly,
# logs an error, shuts down the client and invokes the failure # marks the engine as dead, and shuts down the client.
# callback to inform the engine.
def monitor_engine_cores(): def monitor_engine_cores():
sentinels = [proc.sentinel for proc in engine_processes] engine_manager.monitor_engine_liveness()
died = multiprocessing.connection.wait(sentinels)
_self = self_ref() _self = self_ref()
if not _self or not _self._finalizer.alive or _self.resources.engine_dead: if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
return return
_self.resources.engine_dead = True _self.resources.engine_dead = True
proc_name = next(
proc.name for proc in engine_processes if proc.sentinel == died[0]
)
logger.error(
"Engine core proc %s died unexpectedly, shutting down client.",
proc_name,
)
_self.shutdown() _self.shutdown()
# Note: For MPClient, we don't have a failure callback mechanism # Note: For MPClient, we don't have a failure callback mechanism
# like MultiprocExecutor, but we set engine_dead flag which will # like MultiprocExecutor, but we set engine_dead flag which will
...@@ -1634,6 +1619,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1634,6 +1619,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap() ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
removed_dp_size = cur_data_parallel_size - new_data_parallel_size
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.remove_run_refs_for_scale_down(removed_dp_size)
reconfig_futures = [] reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines): for cur_dp_rank, engine in enumerate(self.core_engines):
reconfig_request = ReconfigureDistributedRequest( reconfig_request = ReconfigureDistributedRequest(
......
...@@ -11,7 +11,7 @@ from enum import Enum, auto ...@@ -11,7 +11,7 @@ from enum import Enum, auto
from multiprocessing import Process, connection from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue from multiprocessing.queues import Queue
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
from unittest.mock import patch from unittest.mock import patch
import msgspec import msgspec
...@@ -133,6 +133,8 @@ class CoreEngineProcManager: ...@@ -133,6 +133,8 @@ class CoreEngineProcManager:
) )
self._finalizer = weakref.finalize(self, shutdown, self.processes) self._finalizer = weakref.finalize(self, shutdown, self.processes)
self.manager_stopped = threading.Event()
self.failed_proc_name: str | None = None
try: try:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks): for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
...@@ -154,12 +156,31 @@ class CoreEngineProcManager: ...@@ -154,12 +156,31 @@ class CoreEngineProcManager:
def shutdown(self, timeout: float | None = None) -> None: def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown engine core processes with configurable timeout.""" """Shutdown engine core processes with configurable timeout."""
self.manager_stopped.set()
if self._finalizer.detach() is not None: if self._finalizer.detach() is not None:
shutdown(self.processes, timeout=timeout) shutdown(self.processes, timeout=timeout)
def join_first(self): def monitor_engine_liveness(self) -> None:
"""Wait for any process to exit.""" """Monitor engine core process liveness."""
connection.wait(proc.sentinel for proc in self.processes)
sentinel_to_proc = {proc.sentinel: proc for proc in self.processes}
sentinels = set(sentinel_to_proc.keys())
while sentinels and not self.manager_stopped.is_set():
died_sentinels = connection.wait(sentinels, timeout=1)
for sentinel in died_sentinels:
proc = sentinel_to_proc.pop(cast(int, sentinel))
exitcode = proc.exitcode
if exitcode != 0 and not self.manager_stopped.is_set():
self.failed_proc_name = proc.name
if died_sentinels:
# Any engine exit currently triggers a shutdown. Future
# work (e.g., Elastic and fault-tolerant EP) will add finer-grained
# handling for different exit scenarios.
break
self.shutdown()
def sentinels(self) -> list: def sentinels(self) -> list:
return [proc.sentinel for proc in self.processes] return [proc.sentinel for proc in self.processes]
...@@ -298,6 +319,8 @@ class CoreEngineActorManager: ...@@ -298,6 +319,8 @@ class CoreEngineActorManager:
self.log_stats = log_stats self.log_stats = log_stats
local_engine_count = vllm_config.parallel_config.data_parallel_size_local local_engine_count = vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size world_size = vllm_config.parallel_config.world_size
self.manager_stopped = threading.Event()
self.failed_proc_name: str | None = None
if ray.is_initialized(): if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.") logger.info("Ray is already initialized. Skipping Ray initialization.")
...@@ -395,8 +418,11 @@ class CoreEngineActorManager: ...@@ -395,8 +418,11 @@ class CoreEngineActorManager:
ray.get(refs) ray.get(refs)
self.run_refs = [] self.run_refs = []
self.actor_run_ref_dict = dict()
for actor in self.local_engine_actors + self.remote_engine_actors: for actor in self.local_engine_actors + self.remote_engine_actors:
self.run_refs.append(actor.run.remote()) ref = actor.run.remote()
self.run_refs.append(ref)
self.actor_run_ref_dict[actor] = ref
@staticmethod @staticmethod
def create_dp_placement_groups( def create_dp_placement_groups(
...@@ -776,7 +802,9 @@ class CoreEngineActorManager: ...@@ -776,7 +802,9 @@ class CoreEngineActorManager:
) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :] ) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :]
for actor in actors: for actor in actors:
self.run_refs.append(actor.run.remote()) ref = actor.run.remote()
self.run_refs.append(ref)
self.actor_run_ref_dict[actor] = ref
cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# Update old_vllm_config with new data_parallel_size_local if any new # Update old_vllm_config with new data_parallel_size_local if any new
...@@ -805,12 +833,59 @@ class CoreEngineActorManager: ...@@ -805,12 +833,59 @@ class CoreEngineActorManager:
self.remote_engine_actors.pop() self.remote_engine_actors.pop()
ray.util.remove_placement_group(pg) ray.util.remove_placement_group(pg)
def remove_run_refs_for_scale_down(self, removed_dp_size: int) -> None:
    """Stop tracking the run refs of the actors selected for scale-down.

    The last ``removed_dp_size`` entries of ``placement_group_is_local``
    are walked from the end. Each entry picks the next actor from the tail
    of ``local_engine_actors`` (flag is true) or ``remote_engine_actors``
    (flag is false), and that actor's run ref is removed from both
    ``actor_run_ref_dict`` and ``run_refs``. The actor lists themselves
    are not modified here.

    Args:
        removed_dp_size: Number of trailing entries to process. Values
            less than or equal to zero make the call a no-op.

    Raises:
        KeyError: If a selected actor has no entry in
            ``actor_run_ref_dict``.
        ValueError: If the actor's ref is not present in ``run_refs``.
    """
    if removed_dp_size <= 0:
        return

    flags = self.placement_group_is_local[-removed_dp_size:]
    # Cursors start at the tail of each actor list and move toward the
    # front as actors of that kind are consumed.
    li = len(self.local_engine_actors) - 1
    ri = len(self.remote_engine_actors) - 1
    for is_local in reversed(flags):
        if is_local:
            actor = self.local_engine_actors[li]
            li -= 1
        else:
            actor = self.remote_engine_actors[ri]
            ri -= 1
        ref = self.actor_run_ref_dict.pop(actor)
        self.run_refs.remove(ref)
def get_run_refs(self): def get_run_refs(self):
return self.run_refs return self.run_refs
def monitor_engine_liveness(self) -> None:
    """Watch the engine actors' run refs until one fails or the manager stops.

    Polls ``ray.wait`` on the refs returned by ``get_run_refs()`` with a
    5 second timeout so that ``manager_stopped`` is re-checked regularly.
    When fetching a completed ref raises a Ray error, the ref is recorded
    in ``failed_proc_name`` and monitoring ends. ``shutdown()`` is always
    called on the way out, including when an unexpected exception
    propagates.
    """
    import ray

    # Refs whose run task already completed without error. ray.wait()
    # returns immediately for a completed ref, so keeping them in the wait
    # list would turn this loop into a busy spin.
    finished_refs: set = set()
    try:
        while not self.manager_stopped.is_set():
            actor_run_refs = [
                ref for ref in self.get_run_refs() if ref not in finished_refs
            ]
            if not actor_run_refs:
                logger.info(
                    "There are no actors to monitor currently. "
                    "The monitoring function is about to terminate."
                )
                break
            actor_done_refs, _ = ray.wait(actor_run_refs, timeout=5)
            unexpected_failure = False
            for actor_ref in actor_done_refs:
                if self.manager_stopped.is_set():
                    break
                if actor_ref not in self.get_run_refs():
                    # The run refs may have been updated by elastic scale-down.
                    continue
                try:
                    ray.get(actor_ref)
                except ray.exceptions.RayError:
                    # RayError is the common base of RayActorError (actor
                    # died) and RayTaskError (run() raised), so both are
                    # reported as an engine failure.
                    self.failed_proc_name = f"Actor {actor_ref}"
                    unexpected_failure = True
                else:
                    finished_refs.add(actor_ref)
            if unexpected_failure:
                break
    finally:
        self.shutdown()
def shutdown(self, timeout: float | None = None) -> None: def shutdown(self, timeout: float | None = None) -> None:
import ray import ray
self.manager_stopped.set()
for actor in self.local_engine_actors + self.remote_engine_actors: for actor in self.local_engine_actors + self.remote_engine_actors:
ray.kill(actor) ray.kill(actor)
for pg in self.created_placement_groups: for pg in self.created_placement_groups:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import argparse import argparse
import contextlib import contextlib
import multiprocessing import multiprocessing
import threading
import time import time
import weakref import weakref
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
...@@ -269,8 +270,6 @@ def wait_for_completion_or_failure( ...@@ -269,8 +270,6 @@ def wait_for_completion_or_failure(
coordinator: The coordinator for data parallel. coordinator: The coordinator for data parallel.
""" """
from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
try: try:
logger.info("Waiting for API servers to complete ...") logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes # Create a mapping of sentinels to their corresponding processes
...@@ -282,33 +281,40 @@ def wait_for_completion_or_failure( ...@@ -282,33 +281,40 @@ def wait_for_completion_or_failure(
if coordinator: if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
actor_run_refs = [] if engine_manager:
if isinstance(engine_manager, CoreEngineProcManager): core_shutdown_recv, core_shutdown_send = connection.Pipe(duplex=False)
for proc in engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc def monitor_engines():
elif isinstance(engine_manager, CoreEngineActorManager): try:
actor_run_refs = engine_manager.get_run_refs() engine_manager.monitor_engine_liveness()
finally:
core_shutdown_send.close()
core_shutdown_recv.close()
# start monitor for engine liveness
threading.Thread(target=monitor_engines, daemon=True).start()
sentinel_to_proc[core_shutdown_recv] = None # type: ignore[assignment]
# Check if any process terminates # Check if any process terminates
while sentinel_to_proc or actor_run_refs: while sentinel_to_proc:
# Wait for any process to terminate # Wait for any process to terminate (or engine shutdown signal)
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5) ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
# Process any terminated processes # Process any terminated processes
for sentinel in ready_sentinels: for sentinel in ready_sentinels:
proc = sentinel_to_proc.pop(sentinel) proc = sentinel_to_proc.pop(sentinel)
# Check if process exited with error # Check if process exited with error
if proc.exitcode != 0: if proc is not None and proc.exitcode != 0:
raise RuntimeError( raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) " f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}" f"died with exit code {proc.exitcode}"
) )
if engine_manager and engine_manager.failed_proc_name is not None:
if actor_run_refs: raise RuntimeError(
import ray f"Engine core process {engine_manager.failed_proc_name} "
"died unexpectedly."
_, actor_run_refs = ray.wait(actor_run_refs, timeout=5) )
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...") logger.info("Received KeyboardInterrupt, shutting down API servers...")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment