Unverified commit 80c751e7, authored by Robert Shaw and committed by GitHub
Browse files

[V1] Simplify Shutdown (#11659)

parent e1a5c2f0
...@@ -142,9 +142,6 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): ...@@ -142,9 +142,6 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client.abort_requests([request.request_id]) client.abort_requests([request.request_id])
# Shutdown the client.
client.shutdown()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_engine_core_client_asyncio(monkeypatch): async def test_engine_core_client_asyncio(monkeypatch):
...@@ -200,6 +197,3 @@ async def test_engine_core_client_asyncio(monkeypatch): ...@@ -200,6 +197,3 @@ async def test_engine_core_client_asyncio(monkeypatch):
else: else:
assert len(outputs[req_id]) == MAX_TOKENS, ( assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}") f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
# Shutdown the client.
client.shutdown()
...@@ -232,11 +232,6 @@ class LLM: ...@@ -232,11 +232,6 @@ class LLM:
self.request_counter = Counter() self.request_counter = Counter()
def __del__(self):
if hasattr(self, 'llm_engine') and self.llm_engine and hasattr(
self.llm_engine, "shutdown"):
self.llm_engine.shutdown()
@staticmethod @staticmethod
def get_engine_class() -> Type[LLMEngine]: def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
......
...@@ -103,9 +103,6 @@ class AsyncLLM(EngineClient): ...@@ -103,9 +103,6 @@ class AsyncLLM(EngineClient):
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
def __del__(self):
self.shutdown()
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
......
...@@ -203,7 +203,6 @@ class EngineCoreProc(EngineCore): ...@@ -203,7 +203,6 @@ class EngineCoreProc(EngineCore):
finally: finally:
if engine_core is not None: if engine_core is not None:
engine_core.shutdown() engine_core.shutdown()
engine_core = None
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore.""" """Core busy loop of the EngineCore."""
......
from typing import List, Optional, Type import weakref
from abc import ABC, abstractmethod
from typing import List, Type
import msgspec import msgspec
import zmq import zmq
...@@ -18,7 +20,7 @@ from vllm.v1.utils import BackgroundProcHandle ...@@ -18,7 +20,7 @@ from vllm.v1.utils import BackgroundProcHandle
logger = init_logger(__name__) logger = init_logger(__name__)
class EngineCoreClient: class EngineCoreClient(ABC):
""" """
EngineCoreClient: subclasses handle different methods for pushing EngineCoreClient: subclasses handle different methods for pushing
and pulling from the EngineCore for asyncio / multiprocessing. and pulling from the EngineCore for asyncio / multiprocessing.
...@@ -52,8 +54,9 @@ class EngineCoreClient: ...@@ -52,8 +54,9 @@ class EngineCoreClient:
return InprocClient(vllm_config, executor_class, log_stats) return InprocClient(vllm_config, executor_class, log_stats)
@abstractmethod
def shutdown(self): def shutdown(self):
pass ...
def get_output(self) -> List[EngineCoreOutput]: def get_output(self) -> List[EngineCoreOutput]:
raise NotImplementedError raise NotImplementedError
...@@ -107,9 +110,6 @@ class InprocClient(EngineCoreClient): ...@@ -107,9 +110,6 @@ class InprocClient(EngineCoreClient):
def shutdown(self): def shutdown(self):
self.engine_core.shutdown() self.engine_core.shutdown()
def __del__(self):
self.shutdown()
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start) self.engine_core.profile(is_start)
...@@ -139,10 +139,14 @@ class MPClient(EngineCoreClient): ...@@ -139,10 +139,14 @@ class MPClient(EngineCoreClient):
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
# ZMQ setup. # ZMQ setup.
if asyncio_mode: self.ctx = (
self.ctx = zmq.asyncio.Context() zmq.asyncio.Context() # type: ignore[attr-defined]
else: if asyncio_mode else zmq.Context()) # type: ignore[attr-defined]
self.ctx = zmq.Context() # type: ignore[attr-defined]
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
self._finalizer = weakref.finalize(self, lambda x: x.destroy(linger=0),
self.ctx)
# Paths and sockets for IPC. # Paths and sockets for IPC.
output_path = get_open_zmq_ipc_path() output_path = get_open_zmq_ipc_path()
...@@ -153,7 +157,6 @@ class MPClient(EngineCoreClient): ...@@ -153,7 +157,6 @@ class MPClient(EngineCoreClient):
zmq.constants.PUSH) zmq.constants.PUSH)
# Start EngineCore in background process. # Start EngineCore in background process.
self.proc_handle: Optional[BackgroundProcHandle]
self.proc_handle = BackgroundProcHandle( self.proc_handle = BackgroundProcHandle(
input_path=input_path, input_path=input_path,
output_path=output_path, output_path=output_path,
...@@ -166,12 +169,11 @@ class MPClient(EngineCoreClient): ...@@ -166,12 +169,11 @@ class MPClient(EngineCoreClient):
}) })
def shutdown(self): def shutdown(self):
# Shut down the zmq context. """Clean up background resources."""
self.ctx.destroy(linger=0) if hasattr(self, "proc_handle"):
if hasattr(self, "proc_handle") and self.proc_handle:
self.proc_handle.shutdown() self.proc_handle.shutdown()
self.proc_handle = None
self._finalizer()
class SyncMPClient(MPClient): class SyncMPClient(MPClient):
......
...@@ -205,10 +205,3 @@ class LLMEngine: ...@@ -205,10 +205,3 @@ class LLMEngine:
f"found type: {type(tokenizer_group)}") f"found type: {type(tokenizer_group)}")
return tokenizer_group return tokenizer_group
def __del__(self):
self.shutdown()
def shutdown(self):
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
import multiprocessing
import os import os
import weakref import weakref
from collections.abc import Sequence from collections.abc import Sequence
...@@ -91,8 +92,6 @@ class BackgroundProcHandle: ...@@ -91,8 +92,6 @@ class BackgroundProcHandle:
target_fn: Callable, target_fn: Callable,
process_kwargs: Dict[Any, Any], process_kwargs: Dict[Any, Any],
): ):
self._finalizer = weakref.finalize(self, self.shutdown)
context = get_mp_context() context = get_mp_context()
reader, writer = context.Pipe(duplex=False) reader, writer = context.Pipe(duplex=False)
...@@ -102,11 +101,11 @@ class BackgroundProcHandle: ...@@ -102,11 +101,11 @@ class BackgroundProcHandle:
process_kwargs["ready_pipe"] = writer process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path process_kwargs["output_path"] = output_path
self.input_path = input_path
self.output_path = output_path
# Run Detokenizer busy loop in background process. # Run busy loop in background process.
self.proc = context.Process(target=target_fn, kwargs=process_kwargs) self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path)
self.proc.start() self.proc.start()
# Wait for startup. # Wait for startup.
...@@ -114,21 +113,24 @@ class BackgroundProcHandle: ...@@ -114,21 +113,24 @@ class BackgroundProcHandle:
raise RuntimeError(f"{process_name} initialization failed. " raise RuntimeError(f"{process_name} initialization failed. "
"See root cause above.") "See root cause above.")
def __del__(self):
self.shutdown()
def shutdown(self): def shutdown(self):
# Shutdown the process if needed. self._finalizer()
if hasattr(self, "proc") and self.proc.is_alive():
self.proc.terminate()
self.proc.join(5) # Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
if self.proc.is_alive(): def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
kill_process_tree(self.proc.pid) # Shutdown the process.
if proc.is_alive():
# Remove zmq ipc socket files proc.terminate()
ipc_sockets = [self.output_path, self.input_path] proc.join(5)
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "") if proc.is_alive():
if os and os.path.exists(socket_file): kill_process_tree(proc.pid)
os.remove(socket_file)
# Remove zmq ipc socket files.
ipc_sockets = [output_path, input_path]
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment