Unverified Commit a5b84f1c authored by dongluw's avatar dongluw Committed by GitHub
Browse files

[Core] Shared memory based object store for Multimodal data caching and IPC (#20452)


Signed-off-by: default avatardonglu <donglu@cohere.com>
parent 9f04d9d5
......@@ -789,6 +789,8 @@ steps:
commands:
- pytest -v -s distributed/test_comm_ops.py
- pytest -v -s distributed/test_shm_broadcast.py
- pytest -v -s distributed/test_shm_buffer.py
- pytest -v -s distributed/test_shm_storage.py
- label: 2 Node Tests (4 GPUs in total) # 16min
timeout_in_minutes: 30
......
......@@ -230,6 +230,20 @@ Multi-modal IPC caching is automatically enabled when
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
to avoid repeatedly transferring the same multi-modal inputs between them.
#### Key-Replicated Cache
By default, IPC caching uses a **key-replicated cache**, where cache keys exist
in both the API (`P0`) and engine core (`P1`) processes, but the actual cache
data resides only in `P1`.
#### Shared Memory Cache
When multiple worker processes are involved (e.g., when TP > 1), a
**shared-memory cache** is more efficient. This can be enabled by setting
`mm_processor_cache_type="shm"`. In this mode, cache keys are stored
on `P0`, while the cache data itself lives in shared memory accessible by all
processes.
### Configuration
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).
......@@ -244,6 +258,12 @@ Examples:
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=8)
# Use a shared-memory based IPC cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
tensor_parallel_size=2,
mm_processor_cache_type="shm",
mm_processor_cache_gb=8)
# Disable the cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=0)
......@@ -253,11 +273,12 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows:
| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
|-------------------|-------------|------------|------------|-------------|
| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
| ❌ | ❌ | N/A | N/A | `0` |
| mm_processor_cache_type | Cache Type | `P0` Cache | `P1` Engine Cache | `P1` Worker Cache | Max. Memory |
|-------------------|-------------|------------|------------|-------------|-------------|
| lru | Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` |
| lru | Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
| shm | Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` |
| N/A | Disabled | N/A | N/A | N/A | `0` |
K: Stores the hashes of multi-modal items
V: Stores the processed tensor data of multi-modal items
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import traceback
import unittest
from vllm.distributed.device_communicators.shm_object_storage import (
SingleWriterShmRingBuffer)
class TestSingleWriterShmRingBuffer(unittest.TestCase):
"""Test suite for the ring buffer implementation"""
def setUp(self):
"""Set up test fixtures"""
self.buffer_size = 4096
self.ring_buffer = None
def tearDown(self):
"""Clean up after tests"""
if self.ring_buffer:
del self.ring_buffer
def test_buffer_opening(self):
"""Test opening an existing buffer"""
# First create a buffer
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True)
# Then open it with another instance
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
self.assertFalse(reader_buffer.is_writer)
self.assertEqual(reader_buffer.shared_memory.name,
self.ring_buffer.shared_memory.name)
def test_buffer_access(self):
"""Test accessing allocated buffers"""
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True)
size = 100
address, monotonic_id = self.ring_buffer.allocate_buf(size)
# Write some test data
test_data = b"Hello, World!" * 7 # 91 bytes
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
data_buf[0:len(test_data)] = test_data
# Read it back
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
read_data = bytes(data_buf2[0:len(test_data)])
read_id = metadata2[0]
self.assertEqual(read_data, test_data)
self.assertEqual(read_id, monotonic_id)
def test_memory_error_on_full_buffer(self):
"""Test that MemoryError is raised when buffer is full"""
small_buffer_size = 200
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=small_buffer_size, create=True)
# Fill up the buffer
self.ring_buffer.allocate_buf(100)
self.ring_buffer.allocate_buf(80) # Total: 196 bytes used
# This should fail
with self.assertRaises(MemoryError):
self.ring_buffer.allocate_buf(1) # Would exceed buffer capacity
def test_allocation_and_free(self):
"""Test allocation and freeing of buffers"""
small_buffer_size = 200
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=small_buffer_size, create=True)
size = 80
# Write some data
test_data = b"Repeated test data"
for i in range(5):
address, monotonic_id = self.ring_buffer.allocate_buf(size)
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use
data_buf[4:len(test_data) + 4] = test_data
print(self.ring_buffer.metadata)
freed_ids = self.ring_buffer.free_buf(lambda *args: True)
print(f" Freed IDs: {freed_ids}")
self.assertEqual(freed_ids[0], i)
def test_clear_buffer(self):
"""Test clearing the buffer"""
self.ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=self.buffer_size, create=True)
# Allocate some buffers
for _ in range(3):
self.ring_buffer.allocate_buf(100)
# Clear the buffer
self.ring_buffer.clear()
# Check that metadata is empty and IDs reset
self.assertEqual(len(self.ring_buffer.metadata), 0)
self.assertEqual(self.ring_buffer.monotonic_id_start, 0)
self.assertEqual(self.ring_buffer.monotonic_id_end, 0)
self.assertEqual(self.ring_buffer.data_buffer_start, 0)
self.assertEqual(self.ring_buffer.data_buffer_end, 0)
def main():
"""Main function demonstrating usage and running tests"""
print("=== SingleWriterShmRingBuffer Test Suite ===\n")
# Run unit tests
print("Running unit tests...")
unittest.main(argv=[""], exit=False, verbosity=2)
print("\n" + "=" * 50)
print("=== Manual Demo ===\n")
# Manual demonstration
try:
print("Creating ring buffer...")
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048,
create=True)
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())
print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
# Allocate some buffers
print("\nAllocating buffers...")
address_array = []
for i in range(3):
size = 100 + i * 50
try:
writer_buffer.free_buf(lambda *args: True)
address, monotonic_id = writer_buffer.allocate_buf(size)
address_array.append((address, size, monotonic_id))
# Write some test data
with writer_buffer.access_buf(address) as (data_buf, metadata):
test_message = f"Test message {i}".encode()
data_buf[0:len(test_message)] = test_message
except MemoryError as e:
print(f" Failed to allocate {size} bytes: {e}")
print("\nBuffer state:")
print(f" Data buffer start: {writer_buffer.data_buffer_start}")
print(f" Data buffer end: {writer_buffer.data_buffer_end}")
print(f" Monotonic ID start: {writer_buffer.monotonic_id_start}")
print(f" Monotonic ID end: {writer_buffer.monotonic_id_end}")
print(f" Metadata entries: {len(writer_buffer.metadata)}")
# Try to read back the data
print("\nReading back data...")
for address, size, monotonic_id in address_array:
with reader_buffer.access_buf(address) as (data_buf, metadata):
# Find null terminator or read first 50 chars
data_bytes = bytes(data_buf[0:size])
message = data_bytes.decode()
print(f" ID {monotonic_id}: '{message}'")
except Exception as e:
print(f"Demo error: {e}")
traceback.print_exc()
print("\n=== Demo Complete ===")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import random
import time
import traceback
import unittest
from multiprocessing import Lock
import torch
# Assuming these are imported from your module
from vllm.distributed.device_communicators.shm_object_storage import (
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalSharedField)
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])
class TestSingleWriterShmObjectStorage(unittest.TestCase):
def setUp(self):
"""Set up test fixtures before each test method."""
ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=1024 * 100,
create=True, # 10 MB buffer
)
self.storage = SingleWriterShmObjectStorage(
max_object_size=1024 * 10, # 10KB max object
n_readers=2,
ring_buffer=ring_buffer,
serde_class=MsgpackSerde,
reader_lock=Lock(),
)
def tearDown(self):
"""Clean up after each test."""
if self.storage:
del self.storage
def test_minimal_put_get_cycle(self):
"""Test basic put and get operations."""
key = "test_key"
value = _dummy_item("text", {"field1": 10, "field2": 20})
# Put operation
address, monotonic_id = self.storage.put(key, value)
# Verify key is in index
self.assertIn(key, self.storage.key_index)
self.assertEqual(self.storage.key_index[key], (address, monotonic_id))
self.assertEqual(self.storage.id_index[monotonic_id], key)
# Get operation
result = self.storage.get(address, monotonic_id)
# Verify result
self.assertEqual(result, value)
def test_put_same_key_twice(self):
"""Test behavior when putting the same key multiple times."""
key = "duplicate_key"
value1 = "first value"
value2 = "second value"
# First put
address1, id1 = self.storage.put(key, value1)
retrieved1 = self.storage.get(address1, id1)
self.assertEqual(retrieved1, value1)
# should raise an error on second put
with self.assertRaises(ValueError) as context:
self.storage.put(key, value2)
self.assertIn("already exists in the storage", str(context.exception))
def test_large_object_rejection(self):
"""Test that objects exceeding max_object_size are rejected."""
# Create an object larger than max_object_size
large_data = "x" * (self.storage.max_object_size + 100)
with self.assertRaises(ValueError) as context:
self.storage.put("large_key", large_data)
self.assertIn("exceeds max object size", str(context.exception))
def test_buffer_overflow_and_cleanup(self):
"""Test behavior when buffer fills up and needs cleanup."""
# Fill up the buffer with many small objects
stored_items = []
try:
for i in range(1000): # Try to store many items
key = f"item_{i}"
value = f"data_{i}" * 100 # Make it reasonably sized
address, monotonic_id = self.storage.put(key, value)
stored_items.append((key, value, address, monotonic_id))
except MemoryError:
print(f"Buffer filled after {len(stored_items)} items")
# Verify that some items are still accessible
accessible_count = 0
for key, original_value, address, monotonic_id in stored_items:
for i in range(self.storage.n_readers):
retrieved = self.storage.get(address, monotonic_id)
if retrieved == original_value:
accessible_count += 1
self.assertEqual(accessible_count, len(stored_items))
try:
for i in range(len(stored_items), 1000): # Try to store many items
key = f"item_{i}"
value = f"data_{i}" * 100 # Make it reasonably sized
address, monotonic_id = self.storage.put(key, value)
stored_items.append((key, value, address, monotonic_id))
except MemoryError:
print(f"Buffer filled after {len(stored_items)} items")
# Verify that some items are still accessibles
for key, original_value, address, monotonic_id in stored_items:
try:
for i in range(self.storage.n_readers):
retrieved = self.storage.get(address, monotonic_id)
if retrieved == original_value:
accessible_count += 1
except ValueError as e:
print(f"Error retrieving {key}: {e}")
# some items from the first batch may still be accessible
self.assertGreaterEqual(accessible_count, len(stored_items))
def test_blocking_unread_object(self):
"""Test behavior when buffer fills up and needs cleanup."""
# Fill up the buffer with many small objects
stored_items = []
try:
for i in range(1000): # Try to store many items
key = f"item_{i}"
value = f"data_{i}" * 100 # Make it reasonably sized
address, monotonic_id = self.storage.put(key, value)
stored_items.append((key, value, address, monotonic_id))
except MemoryError:
print(f"Buffer filled after {len(stored_items)} items")
# read all items except the first one
# to simulate a blocking situation
accessible_count = 0
for key, original_value, address, monotonic_id in stored_items[1:]:
for i in range(self.storage.n_readers):
retrieved = self.storage.get(address, monotonic_id)
if retrieved == original_value:
accessible_count += 1
self.assertEqual(accessible_count, len(stored_items) - 1)
try:
key = f"item_{len(stored_items)}"
value = f"data_{len(stored_items)}" * 100
address, monotonic_id = self.storage.put(key, value)
except MemoryError:
print(f"Buffer filled after {len(stored_items)} items")
# read the first item
for i in range(self.storage.n_readers):
key, original_value, address, monotonic_id = stored_items[0]
retrieved = self.storage.get(address, monotonic_id)
self.assertEqual(retrieved, original_value)
try:
for i in range(len(stored_items), 1000): # Try to store many items
key = f"item_{i}"
value = f"data_{i}" * 100 # Make it reasonably sized
address, monotonic_id = self.storage.put(key, value)
stored_items.append((key, value, address, monotonic_id))
except MemoryError:
print(f"Buffer filled after {len(stored_items)} items")
# some items from the first batch may still be accessible
self.assertGreaterEqual(len(stored_items), accessible_count + 10)
def test_invalid_get_operations(self):
"""Test various invalid get operations."""
# Test with non-existent address
with self.assertRaises(ValueError): # Could be various exceptions
self.storage.get(99999, 1)
# Store something first
address, monotonic_id = self.storage.put("test", "value")
# Test with wrong monotonic_id
with self.assertRaises(ValueError) as context:
self.storage.get(address, monotonic_id + 100)
self.assertIn("has been modified or is invalid", \
str(context.exception))
def test_clear_storage(self):
"""Test clearing the storage."""
# Store some items
for i in range(5):
self.storage.put(f"item_{i}", f"value_{i}")
# Clear the storage
self.storage.clear()
# Verify that all indices are empty
self.assertEqual(len(self.storage.key_index), 0)
self.assertEqual(len(self.storage.id_index), 0)
self.assertEqual(len(self.storage.ring_buffer.metadata), 0)
# Verify that new items can be added after clearing
address, monotonic_id = self.storage.put("new_item", "new_value")
self.assertIn("new_item", self.storage.key_index)
self.assertEqual((address, monotonic_id), (0, 0))
# Reader process function
def reader_process(process_id, storage_handle, items_to_read):
"""Reader process that connects to existing shared memory and reads data."""
reader_storage = SingleWriterShmObjectStorage.create_from_handle(
storage_handle)
print(f"Reader {process_id} started")
errors = []
for key, original_value, address, monotonic_id in items_to_read:
time.sleep(random.random() / 100)
try:
# Read data from shared memory
retrieved_value = reader_storage.get(address, monotonic_id)
# Verify data integrity
assert retrieved_value == original_value
print(f"Reader {process_id} retrieved {key}: {retrieved_value}")
except Exception as e:
errors.append((key, str(e), type(e).__name__))
def run_multiprocess_example():
"""Run a minimal working example with real shared memory."""
print("=== Minimal Object Storage Example ===")
try:
# Create storage instance
ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=1024 * 100,
create=True, # 10 MB buffer
)
storage = SingleWriterShmObjectStorage(
max_object_size=1024,
n_readers=3,
ring_buffer=ring_buffer,
serde_class=MsgpackSerde,
reader_lock=Lock(),
)
print(f"Created storage (writer: {storage.is_writer})")
# Test basic data types
test_data = [
("user_data", {
"name": "Alice",
"age": 30,
"scores": [95, 87, 92]
}),
("simple_string", "Hello, World!"),
("number", 42),
("list_data", [1, 2, 3, "four", 5.0]),
]
stored_items = []
# Store all data
for key, value in test_data:
print(f"Storing {key}: {value}")
address, monotonic_id = storage.put(key, value)
stored_items.append((key, value, address, monotonic_id))
print(f" -> Stored at address {address}, ID {monotonic_id}")
print("\n--- Retrieving Data ---")
processes = []
handle = storage.handle()
# initialize lock for reader processes
handle.reader_lock = Lock()
for i in range(storage.n_readers):
p = multiprocessing.Process(target=reader_process,
args=(i, handle, stored_items))
processes.append(p)
p.start()
for p in processes:
p.join(timeout=10)
if p.is_alive():
p.terminate()
p.join()
except Exception as e:
print(f"Error in minimal example: {e}")
traceback.print_exc()
if __name__ == "__main__":
# Run the minimal example first
run_multiprocess_example()
print("\n" + "=" * 50 + "\n")
# Run the test suite
print("Running comprehensive test suite...")
unittest.main(verbosity=2, exit=False)
......@@ -10,8 +10,8 @@ from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal.cache import (MultiModalCache,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
processor_cache_from_config,
receiver_cache_from_config)
engine_receiver_cache_from_config,
processor_cache_from_config)
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItems,
......@@ -115,9 +115,9 @@ def _compare_caches(
):
mm_registry = MultiModalRegistry()
cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
cache_0_p1 = engine_receiver_cache_from_config(config_0, mm_registry)
cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)
cache_1_p1 = engine_receiver_cache_from_config(config_1, mm_registry)
cache_size_gb = max(
config_0.model_config.mm_processor_cache_gb,
......
......@@ -90,6 +90,7 @@ class DummyExecutor(UniProcExecutor):
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.mm_receiver_cache = None
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
......
......@@ -39,6 +39,7 @@ ALLOWED_FILES = set([
'vllm/engine/multiprocessing/client.py',
'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/distributed/device_communicators/shm_object_storage.py',
'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py',
'benchmarks/kernels/benchmark_lora.py',
......
......@@ -262,6 +262,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"]
class LogprobsMode(enum.Enum):
......@@ -450,6 +451,13 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_processor_cache_type: MMCacheType = "lru"
"""Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
mm_shm_cache_max_object_size_mb: int = 128
"""Size limit (in MiB) for each object stored in the multi-modal processor
shared memory cache. Only effective when `mm_processor_cache_type` is
`"shm"`."""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
......@@ -881,6 +889,9 @@ class ModelConfig:
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_processor_cache_type=self.mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=self.
mm_shm_cache_max_object_size_mb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling,
......@@ -2448,6 +2459,15 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""
mm_processor_cache_type: MMCacheType = "lru"
"""Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
mm_shm_cache_max_object_size_mb: int = 128
"""Size limit (in MiB) for each object stored in the multi-modal processor
shared memory cache. Only effective when `mm_processor_cache_type` is
`"shm"`."""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""
Indicates how to optimize multi-modal encoder inference using
......
This diff is collapsed.
......@@ -27,8 +27,8 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DistributedExecutorBackend, EPLBConfig,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
LoRAConfig, MambaDType, MMCacheType, MMEncoderTPMode,
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
......@@ -373,6 +373,10 @@ class EngineArgs:
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_processor_cache_type: Optional[MMCacheType] = \
MultiModalConfig.mm_processor_cache_type
mm_shm_cache_max_object_size_mb: int = \
MultiModalConfig.mm_shm_cache_max_object_size_mb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
io_processor_plugin: Optional[str] = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
......@@ -782,6 +786,12 @@ class EngineArgs:
multimodal_group.add_argument("--disable-mm-preprocessor-cache",
action="store_true",
deprecated=True)
multimodal_group.add_argument(
"--mm-processor-cache-type",
**multimodal_kwargs["mm_processor_cache_type"])
multimodal_group.add_argument(
"--mm-shm-cache-max-object-size-mb",
**multimodal_kwargs["mm_shm_cache_max_object_size_mb"])
multimodal_group.add_argument(
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
multimodal_group.add_argument(
......@@ -998,6 +1008,9 @@ class EngineArgs:
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_processor_cache_type=self.mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=self.
mm_shm_cache_max_object_size_mb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
......
......@@ -175,6 +175,7 @@ if TYPE_CHECKING:
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
def get_default_cache_root():
......@@ -1241,6 +1242,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# raw bytes. Defaults to True for backward compatibility.
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
# Name of the shared memory buffer used for object storage.
# Only effective when mm_config.mm_processor_cache_type == "shm".
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
}
# --8<-- [end:env-vars-definition]
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from multiprocessing import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
......@@ -10,9 +11,12 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
run_method)
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
......@@ -44,6 +48,8 @@ class UniProcExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
......@@ -55,6 +61,8 @@ class UniProcExecutor(ExecutorBase):
kwargs: Optional[Dict] = None) -> List[Any]:
if kwargs is None:
kwargs = {}
if self.mm_receiver_cache is not None and method == "execute_model":
get_and_update_mm_cache(self.mm_receiver_cache, args)
answer = run_method(self.driver_worker, method, args, kwargs)
return [answer]
......@@ -128,6 +136,8 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
......
......@@ -3,19 +3,24 @@
import sys
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast
import torch
from typing_extensions import TypeAlias, override
from vllm.distributed.device_communicators.shm_object_storage import (
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils import GiB_bytes, LRUCache, MiB_bytes
from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
json_reduce_leaves)
from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalKwargsItems, NestedTensors)
from .inputs import (MultiModalBatchedField, MultiModalFeatureSpec,
MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
......@@ -389,6 +394,106 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
self._cache.clear()
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
"""
The cache which is used on P0 when IPC caching is enabled.
How to update each item:
- If the item is already in the cache, clear the input to avoid
unnecessary IPC.
- If the item is not in the cache, store the data in shared memory.
"""
def __init__(self, vllm_config: "VllmConfig") -> None:
super().__init__()
self.world_size = vllm_config.parallel_config.world_size
mm_config = vllm_config.model_config.get_multimodal_config()
ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
create=True, # sender is the writer
)
self._shm_cache = SingleWriterShmObjectStorage(
max_object_size=mm_config.mm_shm_cache_max_object_size_mb *
MiB_bytes,
n_readers=self.world_size,
ring_buffer=ring_buffer,
serde_class=MsgpackSerde,
)
# cache (prompt_updates, modality) for P0 only
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate],
str]] = {}
@override
def is_cached_item(self, mm_hash: str) -> bool:
return self._shm_cache.is_cached(mm_hash)
@override
def get_and_update_item(
self,
mm_item: MultiModalProcessorCacheInItem,
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if self._shm_cache.is_cached(mm_hash):
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
prompt_updates, modality = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id,
modality), prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
try:
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
# Try to remove dangling items if p0 cache is too large.
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
self.remove_dangling_items()
self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
address_item = self.address_as_item(address, monotonic_id,
mm_item[0].modality)
return address_item, mm_item[1]
except (ValueError, MemoryError) as e:
# put may fail if the object is too large or
# the cache is full.
# In this case we log the error and keep the original mm_input.
logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash,
e)
return mm_item
@override
def clear_cache(self) -> None:
self._shm_cache.clear()
self._p0_cache.clear()
def remove_dangling_items(self) -> None:
"""Remove items that are no longer in the shared memory cache."""
cached_hashes = self._shm_cache.key_index.keys()
dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
for mm_hash in dangling_hashes:
del self._p0_cache[mm_hash]
def address_as_item(self, address: int, monotonic_id: int,
modality: str) -> MultiModalKwargsItem:
addr_elem = MultiModalFieldElem(
modality=modality,
key="address",
data=address,
field=MultiModalBatchedField(),
)
id_elem = MultiModalFieldElem(
modality=modality,
key="monotonic_id",
data=monotonic_id,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
return mm_item
def _enable_processor_cache(
model_config: "ModelConfig",
mm_registry: "MultiModalRegistry",
......@@ -408,6 +513,17 @@ def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
return supports_ipc_cache
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
"""Whether the shared memory based cache should be enabled."""
if not _enable_ipc_cache(vllm_config):
return False
mm_config = vllm_config.model_config.get_multimodal_config()
return mm_config.mm_processor_cache_type == "shm"
def processor_cache_from_config(
vllm_config: "VllmConfig",
mm_registry: "MultiModalRegistry",
......@@ -421,7 +537,9 @@ def processor_cache_from_config(
if not _enable_ipc_cache(vllm_config):
return MultiModalProcessorOnlyCache(model_config)
return MultiModalProcessorSenderCache(model_config)
if not _enable_mm_input_shm_cache(vllm_config):
return MultiModalProcessorSenderCache(model_config)
return ShmObjectStoreSenderCache(vllm_config)
def processor_only_cache_from_config(
......@@ -491,11 +609,68 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache):
self._cache.clear()
def receiver_cache_from_config(
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
"""
The cache which is used on P1 Worker Process when IPC caching is enabled.
How to update each item:
- If the item has an address, replace the input with the cached item.
- If not, return the input.
"""
def __init__(
self,
vllm_config: "VllmConfig",
shared_worker_lock: LockType,
) -> None:
super().__init__()
self.world_size = vllm_config.parallel_config.world_size
mm_config = vllm_config.model_config.get_multimodal_config()
ring_buffer = SingleWriterShmRingBuffer(
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
create=False, # Server is a reader
)
self._shm_cache = SingleWriterShmObjectStorage(
max_object_size=mm_config.mm_shm_cache_max_object_size_mb *
MiB_bytes,
n_readers=self.world_size,
ring_buffer=ring_buffer,
serde_class=MsgpackSerde,
reader_lock=shared_worker_lock,
)
@override
def get_and_update_item(
self,
mm_item: Optional[MultiModalKwargsItem],
mm_hash: str,
) -> MultiModalKwargsItem:
assert mm_item is not None, f"Expected an address item for {mm_hash=}"
if "address" in mm_item:
address = cast(int, mm_item["address"].data)
monotonic_id = cast(int, mm_item["monotonic_id"].data)
return self._shm_cache.get(address, monotonic_id)
return mm_item
@override
def clear_cache(self) -> None:
self._shm_cache.clear()
def engine_receiver_cache_from_config(
vllm_config: "VllmConfig",
mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalReceiverCache]:
"""Return a `BaseMultiModalReceiverCache`, if enabled."""
"""
This is used in the engine process.
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
mm_processor_cache_type=="lru".
"""
model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry):
......@@ -504,4 +679,31 @@ def receiver_cache_from_config(
if not _enable_ipc_cache(vllm_config):
return None
return MultiModalReceiverCache(model_config)
if not _enable_mm_input_shm_cache(vllm_config):
return MultiModalReceiverCache(model_config)
return None
def worker_receiver_cache_from_config(
vllm_config: "VllmConfig",
mm_registry: "MultiModalRegistry",
shared_worker_lock: LockType,
) -> Optional[BaseMultiModalReceiverCache]:
"""
This is used in the worker process.
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
mm_processor_cache_type=="shm".
"""
model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry):
return None
if not _enable_ipc_cache(vllm_config):
return None
if not _enable_mm_input_shm_cache(vllm_config):
return None
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
......@@ -163,6 +163,12 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
MB_bytes = 1_000_000
"""The number of bytes in one megabyte (MB)."""
MiB_bytes = 1 << 20
"""The number of bytes in one mebibyte (MiB)."""
GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""
......
......@@ -23,7 +23,7 @@ from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import receiver_cache_from_config
from vllm.multimodal.cache import engine_receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
......@@ -131,7 +131,7 @@ class EngineCore:
self.use_spec_decode = vllm_config.speculative_config is not None
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = receiver_cache_from_config(
self.mm_receiver_cache = engine_receiver_cache_from_config(
vllm_config, mm_registry)
# Setup batch queue for pipeline parallelism.
......
......@@ -14,6 +14,7 @@ from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from multiprocessing.synchronize import Lock as LockType
from threading import Thread
from typing import Any, Callable, Optional, Union, cast
......@@ -31,10 +32,13 @@ from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
from vllm.executor.multiproc_worker_utils import (
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (decorate_logs, get_distributed_init_method,
get_loopback_ip, get_mp_context, get_open_port,
set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
ModelRunnerOutput)
from vllm.worker.worker_base import WorkerWrapperBase
......@@ -81,6 +85,8 @@ class MultiprocExecutor(Executor):
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
context = get_mp_context()
shared_worker_lock = context.Lock()
unready_workers: list[UnreadyWorkerProcHandle] = []
success = False
try:
......@@ -92,6 +98,7 @@ class MultiprocExecutor(Executor):
rank=rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
))
# Workers must be created before wait_for_ready to avoid
......@@ -380,6 +387,7 @@ class WorkerProc:
rank: int,
distributed_init_method: str,
input_shm_handle: Handle,
shared_worker_lock: LockType,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
......@@ -416,6 +424,10 @@ class WorkerProc:
name="WorkerAsyncOutputCopy")
self.async_output_copy_thread.start()
# Initialize multimodal receiver cache if needed
self.mm_receiver_cache = worker_receiver_cache_from_config(
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock)
# Initialize device
self.worker.init_device()
......@@ -428,11 +440,12 @@ class WorkerProc:
@staticmethod
def make_worker_process(
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput
shared_worker_lock: LockType,
) -> UnreadyWorkerProcHandle:
context = get_mp_context()
# (reader, writer)
......@@ -449,6 +462,7 @@ class WorkerProc:
"input_shm_handle": input_shm_handle,
"ready_pipe": (reader, writer),
"death_pipe": death_reader,
"shared_worker_lock": shared_worker_lock,
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=WorkerProc.worker_main,
......@@ -646,6 +660,10 @@ class WorkerProc:
func = getattr(self.worker, method)
elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker)
# retrieve from shm cache if available
if self.mm_receiver_cache is not None \
and func.__name__ == "execute_model":
get_and_update_mm_cache(self.mm_receiver_cache, args)
output = func(*args, **kwargs)
except Exception as e:
# Notes have been introduced in python 3.11
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.multimodal.cache import ShmObjectStoreReceiverCache
from vllm.v1.core.sched.output import SchedulerOutput
def get_and_update_mm_cache(
receiver_cache: ShmObjectStoreReceiverCache,
args: tuple[SchedulerOutput],
) -> None:
"""
For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory
cache as needed.
Args:
receiver_cache: The receiver cache to update.
args: According to the collective_rpc call of execute_model method in
executor, args is a tuple of only one SchedulerOutput element.
"""
scheduler_output = args[0]
for request_data in scheduler_output.scheduled_new_reqs:
for i in range(len(request_data.mm_kwargs)):
mm_input = request_data.mm_kwargs[i]
request_data.mm_kwargs[i] = \
receiver_cache.get_and_update_item(mm_input, None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment