Unverified Commit 80b70884 authored by Brandon Pelfrey's avatar Brandon Pelfrey Committed by GitHub
Browse files

Add tensor IPC transfer mechanism for multimodal data (#32104)


Signed-off-by: default avatarBrandon Pelfrey <bpelfrey@nvidia.com>
Signed-off-by: default avatarBrandon Pelfrey <brandonpelfrey@gmail.com>
Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
parent 61e381dc
......@@ -278,3 +278,148 @@ def test_custom_class_serialization_disallowed_without_pickle():
with pytest.raises(TypeError):
# Attempt to encode the custom class
encoder.encode(obj)
@dataclass
class RequestWithTensor:
"""Mock request with non-multimodal tensor field like EngineCoreRequest."""
prompt_embeds: torch.Tensor | None
data: str
def test_non_multimodal_tensor_with_ipc():
"""Test that non-multimodal tensor fields work correctly with IPC enabled.
This reproduces the bug where fields like prompt_embeds: torch.Tensor | None
would fail to decode when IPC is enabled because _decode_tensor expected a
raw tensor tuple but received a msgpack-decoded TensorIpcHandle list.
"""
import torch.multiprocessing as torch_mp
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
# Create tensor queues for IPC
tensor_queues = [torch_mp.Queue()]
# Create encoder with IPC sender
sender = TensorIpcSender(tensor_queues[0])
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Create decoder with IPC receiver
receiver = TensorIpcReceiver(tensor_queues[0])
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
# Create a request with a non-multimodal tensor
original_tensor = torch.randn(5, 10, dtype=torch.float32)
request = RequestWithTensor(prompt_embeds=original_tensor, data="test_data")
# Encode the request - this should send the tensor via IPC
encoded = encoder.encode(request)
# Verify encoding succeeded
assert len(encoded) > 0
# Decode the request - this should retrieve the tensor from IPC queue
# Previously this would fail because the decoder tried to unpack the
# handle list as raw tensor bytes metadata.
decoded = decoder.decode(encoded)
# Verify the decoded request matches the original
assert isinstance(decoded, RequestWithTensor)
assert decoded.data == "test_data"
assert decoded.prompt_embeds is not None
assert torch.allclose(decoded.prompt_embeds, original_tensor), (
"Decoded tensor does not match the original tensor."
)
def test_non_multimodal_tensor_with_ipc_none_value():
"""Test that None values for tensor fields work correctly with IPC enabled."""
import torch.multiprocessing as torch_mp
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
# Create tensor queues for IPC
tensor_queues = [torch_mp.Queue()]
# Create encoder with IPC sender
sender = TensorIpcSender(tensor_queues[0])
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Create decoder with IPC receiver
receiver = TensorIpcReceiver(tensor_queues[0])
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
# Create a request with None for the tensor field
request = RequestWithTensor(prompt_embeds=None, data="test_data_with_none")
# Encode and decode the request
encoded = encoder.encode(request)
decoded = decoder.decode(encoded)
# Verify the decoded request matches the original
assert isinstance(decoded, RequestWithTensor)
assert decoded.data == "test_data_with_none"
assert decoded.prompt_embeds is None
def test_multiple_senders_single_receiver_ipc():
"""Test N senders sharing a queue with a single receiver via msgpack.
Simulates the real vLLM topology where multiple API server frontends
each have their own MsgpackEncoder + TensorIpcSender, all putting
tensors onto the same torch.mp queue, and a single engine core
decodes them with one MsgpackDecoder + TensorIpcReceiver.
"""
import torch.multiprocessing as torch_mp
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
num_senders = 3
num_messages_per_sender = 2
tensor_queue = torch_mp.Queue()
# Create N independent senders (each gets its own uuid-based sender_id)
senders = []
encoders = []
for _ in range(num_senders):
s = TensorIpcSender(tensor_queue)
senders.append(s)
encoders.append(MsgpackEncoder(oob_tensor_consumer=s))
# Single receiver
receiver = TensorIpcReceiver(tensor_queue)
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
# Encode messages from all senders, interleaving the order
# so that tensors from different senders land on the queue interleaved.
encoded_payloads: list[tuple[int, int, torch.Tensor, list]] = []
for msg_idx in range(num_messages_per_sender):
for sender_idx in range(num_senders):
tensor = torch.full(
(sender_idx + 1, msg_idx + 2),
float(sender_idx * 100 + msg_idx),
dtype=torch.float32,
)
req = RequestWithTensor(
prompt_embeds=tensor,
data=f"s{sender_idx}_m{msg_idx}",
)
encoded = encoders[sender_idx].encode(req)
encoded_payloads.append((sender_idx, msg_idx, tensor, encoded))
# Decode all messages — the receiver must correctly match each
# tensor handle to the right TensorIpcData from the shared queue.
for sender_idx, msg_idx, original_tensor, encoded in encoded_payloads:
decoded = decoder.decode(encoded)
assert isinstance(decoded, RequestWithTensor)
assert decoded.data == f"s{sender_idx}_m{msg_idx}"
assert decoded.prompt_embeds is not None
assert decoded.prompt_embeds.shape == original_tensor.shape, (
f"Shape mismatch for sender {sender_idx} msg {msg_idx}: "
f"{decoded.prompt_embeds.shape} != {original_tensor.shape}"
)
assert torch.allclose(decoded.prompt_embeds, original_tensor), (
f"Value mismatch for sender {sender_idx} msg {msg_idx}"
)
This diff is collapsed.
......@@ -14,7 +14,12 @@ import vllm.envs as envs
from vllm.config.model_arch import (
ModelArchitectureConfig,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.multimodal import (
MMCacheType,
MMEncoderTPMode,
MMTensorIPC,
MultiModalConfig,
)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter
......@@ -310,6 +315,7 @@ class ModelConfig:
interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None
mm_tensor_ipc: InitVar[MMTensorIPC] = None
def compute_hash(self) -> str:
"""
......@@ -430,6 +436,7 @@ class ModelConfig:
interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
mm_tensor_ipc: MMTensorIPC,
) -> None:
# Keep set served_model_name before maybe_model_redirect(self.model)
self.served_model_name = get_served_model_name(
......@@ -612,6 +619,7 @@ class ModelConfig:
interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate,
mm_tensor_ipc=mm_tensor_ipc,
)
mm_config_kwargs = {
......@@ -1112,6 +1120,22 @@ class ModelConfig:
f"({parallel_config.decode_context_parallel_size})."
)
# torch_shm uses a single IPC queue to rank 0; DP>1 is
# incompatible because API servers can't know which
# CoreEngine the scheduler will assign work to. TP>1 is
# also not supported because this requires broadcasting
# MM tensors between all TP ranks.
if (
self.multimodal_config is not None
and self.multimodal_config.mm_tensor_ipc == "torch_shm"
and parallel_config.world_size_across_dp > 1
):
raise ValueError(
"mm_tensor_ipc='torch_shm' is not supported with "
"data_parallel_size > 1 or tensor_parallel_size > 1 "
"or pipeline_parallel_size > 1."
)
def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
......
......@@ -59,6 +59,7 @@ class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"]
MMTensorIPC = Literal["direct_rpc", "torch_shm"]
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
"""
A dictionary containing an entry for each modality type of dummy data.
......@@ -172,6 +173,11 @@ class MultiModalConfig:
Value sits in range [0;1) and determines fraction of media tokens
from each video to be pruned.
"""
mm_tensor_ipc: MMTensorIPC = "direct_rpc"
"""IPC (inter-process communication) method for multimodal tensors.
- "direct_rpc": Use msgspec serialization via RPC
- "torch_shm": Use torch.multiprocessing shared memory for zero-copy IPC
Defaults to "direct_rpc". """
@field_validator("limit_per_prompt", mode="before")
@classmethod
......
......@@ -766,6 +766,17 @@ class VllmConfig:
else:
self.parallel_config.disable_nccl_for_dp_synchronization = False
if (
self.model_config is not None
and self.model_config.multimodal_config is not None
and self.model_config.multimodal_config.mm_tensor_ipc == "torch_shm"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
raise ValueError(
"torch_shm is known to fail without "
"VLLM_WORKER_MULTIPROC_METHOD set to spawn"
)
from vllm.platforms import current_platform
if (
......
......@@ -79,7 +79,7 @@ from vllm.config.model import (
RunnerOption,
TokenizerMode,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MMTensorIPC
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import (
All2AllBackend,
......@@ -509,6 +509,7 @@ class EngineArgs:
io_processor_plugin: str | None = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
# LoRA fields
enable_lora: bool = False
max_loras: int = LoRAConfig.max_loras
......@@ -1097,6 +1098,9 @@ class EngineArgs:
multimodal_group.add_argument(
"--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
)
multimodal_group.add_argument(
"--mm-tensor-ipc", **multimodal_kwargs["mm_tensor_ipc"]
)
# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
......@@ -1423,6 +1427,7 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors,
video_pruning_rate=self.video_pruning_rate,
mm_tensor_ipc=self.mm_tensor_ipc,
io_processor_plugin=self.io_processor_plugin,
)
......
......@@ -290,7 +290,7 @@ def run_multi_api_server(args: argparse.Namespace):
with launch_core_engines(
vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses):
) as (local_engine_manager, coordinator, addresses, tensor_queue):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
target_server_fn=run_api_server_worker_proc,
......@@ -303,6 +303,7 @@ def run_multi_api_server(args: argparse.Namespace):
stats_update_address=coordinator.get_stats_publish_address()
if coordinator
else None,
tensor_queue=tensor_queue,
)
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
......
......@@ -13,6 +13,7 @@ from enum import IntEnum
from functools import partial
from inspect import isclass, signature
from logging import DEBUG
from multiprocessing.queues import Queue
from typing import Any, TypeVar, cast
import msgspec
......@@ -59,6 +60,7 @@ from vllm.v1.engine import (
UtilityOutput,
UtilityResult,
)
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver
from vllm.v1.engine.utils import (
EngineHandshakeMetadata,
EngineZmqAddresses,
......@@ -788,6 +790,7 @@ class EngineCoreProc(EngineCore):
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
*,
engine_index: int = 0,
):
......@@ -802,6 +805,12 @@ class EngineCoreProc(EngineCore):
self.engines_running = False
self.shutdown_state = EngineShutdownState.RUNNING
# Receiver for tensor IPC
self.tensor_ipc_receiver: TensorIpcReceiver | None = None
if tensor_queue is not None:
self.tensor_ipc_receiver = TensorIpcReceiver(tensor_queue)
logger.info("Using tensor IPC queue for multimodal tensor sharing")
with self._perform_handshakes(
handshake_address,
identity,
......@@ -1340,9 +1349,11 @@ class EngineCoreProc(EngineCore):
):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
# Msgpack serialization decoding with optional tensor IPC receiver.
add_request_decoder = MsgpackDecoder(
EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
)
generic_decoder = MsgpackDecoder(oob_tensor_provider=self.tensor_ipc_receiver)
with ExitStack() as stack, zmq.Context() as ctx:
input_sockets = [
......@@ -1418,10 +1429,7 @@ class EngineCoreProc(EngineCore):
self.input_queue.put_nowait((request_type, request))
def process_output_sockets(
self,
output_paths: list[str],
coord_output_path: str | None,
engine_index: int,
self, output_paths: list[str], coord_output_path: str | None, engine_index: int
):
"""Output socket IO thread."""
......@@ -1580,6 +1588,7 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
):
assert vllm_config.model_config.is_moe, (
"DPEngineCoreProc should only be used for MoE models"
......@@ -1605,6 +1614,7 @@ class DPEngineCoreProc(EngineCoreProc):
log_stats,
client_handshake_address,
engine_index=dp_rank,
tensor_queue=tensor_queue,
)
def _init_data_parallel(self, vllm_config: VllmConfig):
......
......@@ -12,6 +12,7 @@ from collections import defaultdict, deque
from collections.abc import Awaitable, Callable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from multiprocessing.queues import Queue
from threading import Thread
from typing import Any, TypeAlias, TypeVar
......@@ -45,6 +46,7 @@ from vllm.v1.engine import (
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.tensor_ipc import TensorIpcSender
from vllm.v1.engine.utils import (
CoreEngineActorManager,
CoreEngineProcManager,
......@@ -477,9 +479,6 @@ class MPClient(EngineCoreClient):
client_addresses: dict[str, str] | None = None,
):
self.vllm_config = vllm_config
# Serialization setup.
self.encoder = MsgpackEncoder()
self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup.
sync_ctx = zmq.Context(io_threads=2)
......@@ -501,11 +500,14 @@ class MPClient(EngineCoreClient):
enable_input_socket_handover = parallel_config.enable_elastic_ep
self.stats_update_address: str | None = None
tensor_queue: Queue | None = None
if client_addresses:
# Engines are managed externally to this client.
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address")
# Tensor queues passed via client_addresses for multi-API-server case
tensor_queue = client_addresses.get("tensor_queue") # type: ignore[assignment]
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx,
input_address,
......@@ -532,7 +534,7 @@ class MPClient(EngineCoreClient):
with launch_core_engines(
vllm_config, executor_class, log_stats, addresses
) as (engine_manager, coordinator, addresses):
) as (engine_manager, coordinator, addresses, tensor_queue):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
......@@ -542,6 +544,17 @@ class MPClient(EngineCoreClient):
coordinator.get_stats_publish_address()
)
# Serialization setup with tensor queues for multimodal tensor IPC.
tensor_ipc_sender: TensorIpcSender | None = None
model_config = getattr(vllm_config, "model_config", None)
if model_config is not None and model_config.multimodal_config is not None:
mm_tensor_ipc = model_config.multimodal_config.mm_tensor_ipc
if mm_tensor_ipc == "torch_shm" and tensor_queue is not None:
tensor_ipc_sender = TensorIpcSender(tensor_queue)
self.encoder = MsgpackEncoder(oob_tensor_consumer=tensor_ipc_sender)
self.decoder = MsgpackDecoder(EngineCoreOutputs)
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index
dp_local_size = parallel_config.data_parallel_size_local
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tensor IPC transport via torch.multiprocessing.Queue.
This module contains the queue-based transport logic for sharing tensors
between processes (e.g., API server -> engine core). The msgpack layer
emits/consumes lightweight :class:`TensorIpcData` values, while transport
state such as request association, handle generation, queue routing, buffering,
and cleanup lives here.
"""
import dataclasses
import uuid
from collections import defaultdict
from dataclasses import field
from multiprocessing.queues import Queue as MPQueue
from typing import Any
import torch
from vllm.logger import init_logger
from vllm.v1.serial_utils import OOBTensorConsumer
logger = init_logger(__name__)
TensorIpcQueue = MPQueue
@dataclasses.dataclass
class TensorIpcData:
"""
Data sent via torch.multiprocessing.Queue for zero-copy IPC.
Contains the tensor_id and the actual tensor. The tensor is
shared in memory (GPU or CPU) for efficient inter-process communication.
"""
sender_id: str
message_id: int
tensor_id: int
tensor: torch.Tensor
class TensorIpcSender(OOBTensorConsumer):
"""Send-side logic for tensor IPC via torch.multiprocessing.Queue.
Uses a single queue targeting rank 0 (the only rank that consumes
multimodal tensors during TP>1 / PP>1. Note: DP>1 not supported).
"""
def __init__(self, queue: TensorIpcQueue):
self.queue = queue
self._tensor_id_counter = 0
self._message_counter = 0
self._sender_id = uuid.uuid4().hex[:8]
def set_target_engine(self, target_engine: int) -> None:
if target_engine != 0:
raise IndexError(
"TensorIpcSender only supports a single queue; "
f"got target engine {target_engine}"
)
def new_message(self) -> None:
self._message_counter += 1
self._tensor_id_counter = 0
def __call__(self, tensor: torch.Tensor) -> dict[str, Any] | None:
"""Send tensor via queue, return its handle. Returns None if failed."""
try:
# Move tensor to shared memory for IPC
# This is required for proper inter-process communication
if not tensor.is_shared():
tensor = tensor.share_memory_()
metadata = {
"sender_id": self._sender_id,
"message_id": self._message_counter,
"tensor_id": self._tensor_id_counter,
}
self._tensor_id_counter += 1
ipc_data = TensorIpcData(**metadata, tensor=tensor) # type: ignore[arg-type]
# Use a timeout to avoid blocking indefinitely
self.queue.put(ipc_data, timeout=10.0)
logger.debug(
"Sent tensor %s for (shape=%s, device=%s) "
"via IPC queue (shared memory)",
metadata,
tensor.shape,
tensor.device,
)
return metadata
except Exception as e:
logger.warning(
"Failed to send tensor via IPC queue: %s. "
"Falling back to standard serialization.",
e,
)
return None
@dataclasses.dataclass
class _Sender:
current_message_id: int = -1
tensors: dict[int, dict[int, torch.Tensor]] = field(default_factory=dict)
class TensorIpcReceiver:
"""Receive-side logic for tensor IPC via torch.multiprocessing.Queue.
Wraps the queue receive logic previously embedded in MsgpackDecoder.
"""
def __init__(self, queue: TensorIpcQueue):
self.queue = queue
self._tensor_buffers = defaultdict[str, _Sender](_Sender)
def __call__(
self, dtype: str, shape: tuple[int, ...], meta: dict[str, Any]
) -> torch.Tensor:
"""Retrieve a tensor from torch.multiprocessing.Queue.
Uses a drain-and-buffer pattern: drains all available tensors from
the queue, buffering them, until the requested tensor is found.
Works for CUDA and CPU.
"""
# Create lookup key from handle
sender_id: str = meta["sender_id"]
message_id: int = meta["message_id"]
tensor_id: int = meta["tensor_id"]
# Drain all available tensors. We save them regardless if this is
# the one we're waiting for as they may arrive out of order from
# multiple producers.
while True:
sender = self._tensor_buffers.get(sender_id)
if sender is not None:
tensors = sender.tensors
tensor = tensors.get(message_id, {}).pop(tensor_id, None)
if tensor is not None:
if sender.current_message_id != message_id:
while tensors and (mid := next(iter(tensors))) < message_id:
if sender.tensors.pop(mid):
logger.warning(
"Discarding %d stale tensors from sender %s",
sender_id,
)
sender.current_message_id = message_id
logger.debug(
"Received tensor %s from sender %s for (shape=%s, device=%s) "
"via IPC queue (shared memory)",
(message_id, tensor_id),
sender_id,
tensor.shape,
tensor.device,
)
return tensor
ipc_data: TensorIpcData = self.queue.get(timeout=10.0)
# Store tensor
sender = self._tensor_buffers[ipc_data.sender_id]
if sender.current_message_id > ipc_data.message_id:
logger.warning(
"Ignoring stale tensor from sender %s", ipc_data.sender_id
)
continue
sender.tensors.setdefault(ipc_data.message_id, {})[ipc_data.tensor_id] = (
ipc_data.tensor
)
......@@ -10,6 +10,7 @@ from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import TYPE_CHECKING
from unittest.mock import patch
......@@ -95,6 +96,7 @@ class CoreEngineProcManager:
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
):
context = get_mp_context()
common_kwargs = {
......@@ -103,6 +105,7 @@ class CoreEngineProcManager:
"handshake_address": handshake_address,
"executor_class": executor_class,
"log_stats": log_stats,
"tensor_queue": tensor_queue,
}
if client_handshake_address:
......@@ -864,6 +867,7 @@ def launch_core_engines(
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
Queue | None,
]
]:
"""Launch engine and DP coordinator processes as needed."""
......@@ -878,6 +882,14 @@ def launch_core_engines(
offline_mode = local_start_index is not None
# Create a single tensor IPC queue for sharing multimodal tensors between
# API servers and engine core. Returns a single queue since we only support
# DP=1 for this data flow.
tensor_queue: Queue | None = None
multimodal_config = vllm_config.model_config.multimodal_config
if multimodal_config is not None and multimodal_config.mm_tensor_ipc == "torch_shm":
tensor_queue = get_mp_context().Queue()
# Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
......@@ -913,7 +925,7 @@ def launch_core_engines(
log_stats=log_stats,
)
yield engine_actor_manager, coordinator, addresses
yield engine_actor_manager, coordinator, addresses, tensor_queue
return
if offline_mode:
......@@ -975,11 +987,12 @@ def launch_core_engines(
local_engine_count=local_engine_count,
start_index=dp_rank,
local_start_index=local_start_index or 0,
tensor_queue=tensor_queue,
)
else:
local_engine_manager = None
yield local_engine_manager, coordinator, addresses
yield local_engine_manager, coordinator, addresses, tensor_queue
# Now wait for engines to start.
wait_for_engine_startup(
......
......@@ -4,6 +4,7 @@
import dataclasses
import importlib
import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
......@@ -53,6 +54,27 @@ MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
class OOBTensorConsumer(ABC):
@abstractmethod
def __call__(self, tensor: torch.Tensor) -> dict | None:
"""
Called with tensors for the current message.
Returns None to reject the tensor (falls back to regular serialization),
otherwise a dict with arbitrary placeholder data to be included
in the serialized message.
"""
return None
@abstractmethod
def new_message(self) -> None:
"""Called at the start of each new encoded message."""
pass
# dtype, shape, metadata -> tensor
OOBTensorProvider = Callable[[str, tuple[int, ...], dict], torch.Tensor]
def _log_insecure_serialization_warning():
logger.warning_once(
"Allowing insecure serialization using pickle due to "
......@@ -119,9 +141,16 @@ class MsgpackEncoder:
By default, arrays below 256B are serialized inline Larger will get sent
via dedicated messages. Note that this is a per-tensor limit.
When a ``oob_tensor_consumer`` is provided, tensors (CUDA and CPU) will be
offered to it for out-of-band handling.
"""
def __init__(self, size_threshold: int | None = None):
def __init__(
self,
size_threshold: int | None = None,
oob_tensor_consumer: OOBTensorConsumer | None = None,
):
if size_threshold is None:
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
......@@ -130,11 +159,14 @@ class MsgpackEncoder:
# pass custom data to the hook otherwise.
self.aux_buffers: list[bytestr] | None = None
self.size_threshold = size_threshold
self.oob_tensor_consumer = oob_tensor_consumer
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
_log_insecure_serialization_warning()
def encode(self, obj: Any) -> Sequence[bytestr]:
try:
if self.oob_tensor_consumer is not None:
self.oob_tensor_consumer.new_message()
self.aux_buffers = bufs = [b""]
bufs[0] = self.encoder.encode(obj)
# This `bufs` list allows us to collect direct pointers to backing
......@@ -147,6 +179,8 @@ class MsgpackEncoder:
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
try:
if self.oob_tensor_consumer is not None:
self.oob_tensor_consumer.new_message()
self.aux_buffers = [buf]
bufs = self.aux_buffers
self.encoder.encode_into(obj, buf)
......@@ -222,17 +256,19 @@ class MsgpackEncoder:
def _encode_tensor(
self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], int | memoryview]:
assert self.aux_buffers is not None
) -> tuple[str, tuple[int, ...], int | dict | memoryview]:
oob_consumer = self.oob_tensor_consumer
# view the tensor as a contiguous 1D array of bytes
arr_data = tensor_data(obj)
if obj.nbytes < self.size_threshold:
if obj.nbytes < self.size_threshold and obj.is_cpu:
# Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj))
elif oob_consumer is not None and (data := oob_consumer(obj)) is not None:
assert isinstance(data, dict)
else:
# Otherwise encode index of backing buffer to avoid copy.
assert self.aux_buffers is not None
data = len(self.aux_buffers)
self.aux_buffers.append(arr_data)
self.aux_buffers.append(tensor_data(obj))
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
......@@ -279,9 +315,17 @@ class MsgpackDecoder:
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
``oob_tensor_provider`` must be used when an OOBTensorConsumer is used on the
encoder side.
"""
def __init__(self, t: Any | None = None, share_mem: bool = True):
def __init__(
self,
t: Any | None = None,
share_mem: bool = True,
oob_tensor_provider: OOBTensorProvider | None = None,
):
self.share_mem = share_mem
self.pin_tensors = is_pin_memory_available()
args = () if t is None else (t,)
......@@ -289,6 +333,7 @@ class MsgpackDecoder:
*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
)
self.aux_buffers: Sequence[bytestr] = ()
self.oob_tensor_provider = oob_tensor_provider
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
_log_insecure_serialization_warning()
......@@ -353,6 +398,12 @@ class MsgpackDecoder:
def _decode_tensor(self, arr: Any) -> torch.Tensor:
dtype, shape, data = arr
if isinstance(data, dict):
assert self.oob_tensor_provider, (
"Received OOB tensor but tensor provider is not set"
)
return self.oob_tensor_provider(dtype, shape, data)
is_aux = isinstance(data, int)
buffer = self.aux_buffers[data] if is_aux else data
buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
......
......@@ -10,6 +10,7 @@ from contextlib import AbstractContextManager
from dataclasses import dataclass
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import (
TYPE_CHECKING,
Any,
......@@ -173,6 +174,7 @@ class APIServerProcessManager:
input_addresses: list[str],
output_addresses: list[str],
stats_update_address: str | None = None,
tensor_queue: Queue | None = None,
):
"""Initialize and start API server worker processes.
......@@ -185,6 +187,7 @@ class APIServerProcessManager:
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
tensor_queue: Optional tensor IPC queue for sharing MM tensors
"""
self.listen_address = listen_address
self.sock = sock
......@@ -205,6 +208,8 @@ class APIServerProcessManager:
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address
if tensor_queue is not None:
client_config["tensor_queue"] = tensor_queue
proc = spawn_context.Process(
target=target_server_fn,
......@@ -419,7 +424,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview:
Returns:
A memoryview of the tensor data as uint8.
"""
return tensor.flatten().contiguous().view(torch.uint8).numpy().data
return tensor.flatten().cpu().contiguous().view(torch.uint8).numpy().data
@dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment