Unverified commit 8be263c3, authored by Walter Beller-Morales and committed by GitHub
Browse files

[Core] Cleanup shm based object store on engine shutdown (#32429)


Signed-off-by: walterbm <walter.beller.morales@gmail.com>
parent e1a34c3a
......@@ -22,7 +22,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
def tearDown(self):
    """Clean up after tests.

    Close the ring buffer explicitly instead of deleting the attribute:
    removing ``self.ring_buffer`` first would make the following
    ``close()`` call fail with ``AttributeError``.
    """
    if self.ring_buffer:
        self.ring_buffer.close()
def test_buffer_opening(self):
"""Test opening an existing buffer"""
......
......@@ -56,7 +56,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
def tearDown(self):
    """Clean up after each test.

    Close the storage explicitly instead of deleting the attribute:
    removing ``self.storage`` first would make the following ``close()``
    call fail with ``AttributeError``.
    """
    if self.storage:
        self.storage.close()
def test_minimal_put_get_cycle(self):
"""Test basic put and get operations."""
......
......@@ -4,7 +4,7 @@
import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from dataclasses import dataclass
from itertools import chain
from multiprocessing import shared_memory
......@@ -126,6 +126,7 @@ class SingleWriterShmRingBuffer:
self.data_buffer_end = 0
if create:
logger.debug("Creating new shared memory buffer: %s", name)
# we are creating a buffer
self.metadata: dict[int, int] = {} # monotonic_id -> start address
self.shared_memory = shared_memory.SharedMemory(
......@@ -169,12 +170,17 @@ class SingleWriterShmRingBuffer:
self.data_buffer_start = 0
self.data_buffer_end = 0
def close(self) -> None:
    """Close the shared memory.

    Closes this object's ``shared_memory`` handle and, when this object
    is the writer, also unlinks the segment. ``FileNotFoundError`` from
    ``unlink()`` is suppressed, so calling ``close()`` explicitly and then
    again from ``__del__`` does not raise.
    """
    # ``shared_memory`` may be missing if construction did not get far
    # enough to create or open the segment.
    if hasattr(self, "shared_memory"):
        self.shared_memory.close()
        if self.is_writer:
            # The segment may already have been unlinked by a prior close().
            with suppress(FileNotFoundError):
                self.shared_memory.unlink()

def __del__(self):
    """Fall back to ``close()`` when the object is garbage collected."""
    self.close()
def int2byte(self, integer: int) -> bytes:
    """Encode ``integer`` as signed little-endian bytes of width ``self.ID_NBYTES``."""
    width = self.ID_NBYTES
    encoded = integer.to_bytes(width, "little", signed=True)
    return encoded
......@@ -663,6 +669,10 @@ class SingleWriterShmObjectStorage:
if reader_count >= self.n_readers:
self.increment_reader_flag(data_view[: self.flag_bytes])
def close(self) -> None:
    """Close the shared memory by closing the underlying ring buffer."""
    ring = self.ring_buffer
    ring.close()
def handle(self):
"""Get handle for sharing across processes."""
return ShmObjectStorageHandle(
......
......@@ -7,6 +7,7 @@ import logging
import os
import sys
import tempfile
import uuid
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal
......@@ -457,6 +458,27 @@ def get_vllm_port() -> int | None:
raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err
def get_env_or_set_default(
    env_name: str,
    default_factory: Callable[[], str],
) -> Callable[[], str]:
    """
    Build a zero-argument getter for the environment variable ``env_name``.

    When called, the getter returns the variable's current value if it is
    set. Otherwise it calls ``default_factory`` once, stores the result in
    ``os.environ[env_name]``, and returns it, so later calls see the same
    value.
    """

    def _get_or_set_default() -> str:
        resolved = os.environ.get(env_name)
        if resolved is None:
            # Not set yet: generate the default lazily and publish it.
            resolved = default_factory()
            os.environ[env_name] = resolved
        return resolved

    return _get_or_set_default
# The start-* and end* here are used by the documentation generator
# to extract the used env vars.
......@@ -1558,8 +1580,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Name of the shared memory buffer used for object storage.
# Only effective when mm_config.mm_processor_cache_type == "shm".
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": lambda: os.getenv(
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER"
# Automatically generates a unique UUID-based name per process tree
# if not explicitly set.
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": get_env_or_set_default(
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
lambda: f"VLLM_OBJECT_STORAGE_SHM_BUFFER_{uuid.uuid4().hex}",
),
# The size in MB of the buffers (NVL and RDMA) used by DeepEP
"VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int(
......
......@@ -295,6 +295,10 @@ class BaseMultiModalProcessorCache(
"""
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
def close(self) -> None:
    """Close the underlying cache, if needed.

    This implementation does nothing.
    """
    return None
@abstractmethod
def touch_sender_cache_item(self, mm_hash: str) -> None:
"""
......@@ -534,6 +538,10 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    """Return the result of ``self._stat``, forwarding the ``delta`` flag."""
    stats = self._stat(delta=delta)
    return stats
@override
def close(self) -> None:
    """Close the underlying shared memory cache."""
    shm_cache = self._shm_cache
    shm_cache.close()
def remove_dangling_items(self) -> None:
"""Remove items that are no longer in the shared memory cache."""
cached_hashes = self._shm_cache.key_index.keys()
......
......@@ -249,6 +249,9 @@ class AsyncLLM(EngineClient):
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
if input_processor := getattr(self, "input_processor", None):
input_processor.close()
handler = getattr(self, "output_handler", None)
if handler is not None:
cancel_task_threadsafe(handler)
......
......@@ -712,3 +712,7 @@ class InputProcessor:
def clear_mm_cache(self) -> None:
    """Clear the mm cache by delegating to ``self.input_preprocessor``."""
    preprocessor = self.input_preprocessor
    preprocessor.clear_mm_cache()
def close(self) -> None:
    """Close ``self.mm_processor_cache`` when one is present."""
    cache = self.mm_processor_cache
    if cache is None:
        # No cache configured: nothing to release.
        return
    cache.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment