Unverified Commit 8be263c3 authored by Walter Beller-Morales's avatar Walter Beller-Morales Committed by GitHub
Browse files

[Core] Cleanup shm based object store on engine shutdown (#32429)


Signed-off-by: default avatarwalterbm <walter.beller.morales@gmail.com>
parent e1a34c3a
...@@ -22,7 +22,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Clean up after tests""" """Clean up after tests"""
if self.ring_buffer: if self.ring_buffer:
del self.ring_buffer self.ring_buffer.close()
def test_buffer_opening(self): def test_buffer_opening(self):
"""Test opening an existing buffer""" """Test opening an existing buffer"""
......
...@@ -56,7 +56,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Clean up after each test.""" """Clean up after each test."""
if self.storage: if self.storage:
del self.storage self.storage.close()
def test_minimal_put_get_cycle(self): def test_minimal_put_get_cycle(self):
"""Test basic put and get operations.""" """Test basic put and get operations."""
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import pickle import pickle
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from contextlib import contextmanager from contextlib import contextmanager, suppress
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain from itertools import chain
from multiprocessing import shared_memory from multiprocessing import shared_memory
...@@ -126,6 +126,7 @@ class SingleWriterShmRingBuffer: ...@@ -126,6 +126,7 @@ class SingleWriterShmRingBuffer:
self.data_buffer_end = 0 self.data_buffer_end = 0
if create: if create:
logger.debug("Creating new shared memory buffer: %s", name)
# we are creating a buffer # we are creating a buffer
self.metadata: dict[int, int] = {} # monotonic_id -> start address self.metadata: dict[int, int] = {} # monotonic_id -> start address
self.shared_memory = shared_memory.SharedMemory( self.shared_memory = shared_memory.SharedMemory(
...@@ -169,11 +170,16 @@ class SingleWriterShmRingBuffer: ...@@ -169,11 +170,16 @@ class SingleWriterShmRingBuffer:
self.data_buffer_start = 0 self.data_buffer_start = 0
self.data_buffer_end = 0 self.data_buffer_end = 0
def __del__(self): def close(self) -> None:
"""Close the shared memory."""
if hasattr(self, "shared_memory"): if hasattr(self, "shared_memory"):
self.shared_memory.close() self.shared_memory.close()
if self.is_writer: if self.is_writer:
self.shared_memory.unlink() with suppress(FileNotFoundError):
self.shared_memory.unlink()
def __del__(self):
self.close()
def int2byte(self, integer: int) -> bytes: def int2byte(self, integer: int) -> bytes:
"""Convert an integer to bytes.""" """Convert an integer to bytes."""
...@@ -663,6 +669,10 @@ class SingleWriterShmObjectStorage: ...@@ -663,6 +669,10 @@ class SingleWriterShmObjectStorage:
if reader_count >= self.n_readers: if reader_count >= self.n_readers:
self.increment_reader_flag(data_view[: self.flag_bytes]) self.increment_reader_flag(data_view[: self.flag_bytes])
def close(self) -> None:
"""Close the shared memory."""
self.ring_buffer.close()
def handle(self): def handle(self):
"""Get handle for sharing across processes.""" """Get handle for sharing across processes."""
return ShmObjectStorageHandle( return ShmObjectStorageHandle(
......
...@@ -7,6 +7,7 @@ import logging ...@@ -7,6 +7,7 @@ import logging
import os import os
import sys import sys
import tempfile import tempfile
import uuid
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
...@@ -457,6 +458,27 @@ def get_vllm_port() -> int | None: ...@@ -457,6 +458,27 @@ def get_vllm_port() -> int | None:
raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err
def get_env_or_set_default(
env_name: str,
default_factory: Callable[[], str],
) -> Callable[[], str]:
"""
Create a lambda that returns an environment variable value if set,
or generates and sets a default value using the provided factory function.
"""
def _get_or_set_default() -> str:
value = os.getenv(env_name)
if value is not None:
return value
default_value = default_factory()
os.environ[env_name] = default_value
return default_value
return _get_or_set_default
# The start-* and end* here are used by the documentation generator # The start-* and end* here are used by the documentation generator
# to extract the used env vars. # to extract the used env vars.
...@@ -1558,8 +1580,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1558,8 +1580,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
# Name of the shared memory buffer used for object storage. # Name of the shared memory buffer used for object storage.
# Only effective when mm_config.mm_processor_cache_type == "shm". # Only effective when mm_config.mm_processor_cache_type == "shm".
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": lambda: os.getenv( # Automatically generates a unique UUID-based name per process tree
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER" # if not explicitly set.
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": get_env_or_set_default(
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
lambda: f"VLLM_OBJECT_STORAGE_SHM_BUFFER_{uuid.uuid4().hex}",
), ),
# The size in MB of the buffers (NVL and RDMA) used by DeepEP # The size in MB of the buffers (NVL and RDMA) used by DeepEP
"VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int( "VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int(
......
...@@ -295,6 +295,10 @@ class BaseMultiModalProcessorCache( ...@@ -295,6 +295,10 @@ class BaseMultiModalProcessorCache(
""" """
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
def close(self) -> None:
"""Close the underlying cache, if needed."""
pass
@abstractmethod @abstractmethod
def touch_sender_cache_item(self, mm_hash: str) -> None: def touch_sender_cache_item(self, mm_hash: str) -> None:
""" """
...@@ -534,6 +538,10 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): ...@@ -534,6 +538,10 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
def make_stats(self, *, delta: bool = False) -> CacheInfo: def make_stats(self, *, delta: bool = False) -> CacheInfo:
return self._stat(delta=delta) return self._stat(delta=delta)
@override
def close(self) -> None:
self._shm_cache.close()
def remove_dangling_items(self) -> None: def remove_dangling_items(self) -> None:
"""Remove items that are no longer in the shared memory cache.""" """Remove items that are no longer in the shared memory cache."""
cached_hashes = self._shm_cache.key_index.keys() cached_hashes = self._shm_cache.key_index.keys()
......
...@@ -249,6 +249,9 @@ class AsyncLLM(EngineClient): ...@@ -249,6 +249,9 @@ class AsyncLLM(EngineClient):
if engine_core := getattr(self, "engine_core", None): if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown() engine_core.shutdown()
if input_processor := getattr(self, "input_processor", None):
input_processor.close()
handler = getattr(self, "output_handler", None) handler = getattr(self, "output_handler", None)
if handler is not None: if handler is not None:
cancel_task_threadsafe(handler) cancel_task_threadsafe(handler)
......
...@@ -712,3 +712,7 @@ class InputProcessor: ...@@ -712,3 +712,7 @@ class InputProcessor:
def clear_mm_cache(self) -> None: def clear_mm_cache(self) -> None:
self.input_preprocessor.clear_mm_cache() self.input_preprocessor.clear_mm_cache()
def close(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment