"tests/vscode:/vscode.git/clone" did not exist on "3ffa52009f35c4398a86b7cdd83d4031bf19651c"
Unverified commit 80b70884, authored by Brandon Pelfrey, committed by GitHub
Browse files

Add tensor IPC transfer mechanism for multimodal data (#32104)


Signed-off-by: Brandon Pelfrey <bpelfrey@nvidia.com>
Signed-off-by: Brandon Pelfrey <brandonpelfrey@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 61e381dc
...@@ -278,3 +278,148 @@ def test_custom_class_serialization_disallowed_without_pickle(): ...@@ -278,3 +278,148 @@ def test_custom_class_serialization_disallowed_without_pickle():
with pytest.raises(TypeError): with pytest.raises(TypeError):
# Attempt to encode the custom class # Attempt to encode the custom class
encoder.encode(obj) encoder.encode(obj)
@dataclass
class RequestWithTensor:
"""Mock request with non-multimodal tensor field like EngineCoreRequest."""
prompt_embeds: torch.Tensor | None
data: str
def test_non_multimodal_tensor_with_ipc():
    """Round-trip a non-multimodal tensor field through the IPC transport.

    Regression test: with IPC enabled, a field such as
    ``prompt_embeds: torch.Tensor | None`` used to fail to decode because
    the tensor decoder expected a raw tensor tuple but was handed the
    msgpack-decoded IPC handle instead.
    """
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    # A single queue shared by the sending and the receiving side.
    ipc_queue = torch_mp.Queue()

    # Encoder hands tensors to the IPC sender.
    ipc_sender = TensorIpcSender(ipc_queue)
    encoder = MsgpackEncoder(oob_tensor_consumer=ipc_sender)

    # Decoder pulls tensors back out through the IPC receiver.
    ipc_receiver = TensorIpcReceiver(ipc_queue)
    decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=ipc_receiver)

    # Request carrying a tensor outside of any multimodal structure.
    expected = torch.randn(5, 10, dtype=torch.float32)
    outgoing = RequestWithTensor(prompt_embeds=expected, data="test_data")

    # Encoding should ship the tensor over the queue and still produce
    # a non-empty payload.
    payload = encoder.encode(outgoing)
    assert len(payload) > 0

    # Decoding must fetch the tensor from the queue; before the fix the
    # decoder tried to unpack the handle as raw tensor bytes metadata.
    incoming = decoder.decode(payload)

    # The decoded request must match what was sent.
    assert isinstance(incoming, RequestWithTensor)
    assert incoming.data == "test_data"
    assert incoming.prompt_embeds is not None
    assert torch.allclose(incoming.prompt_embeds, expected), (
        "Decoded tensor does not match the original tensor."
    )
def test_non_multimodal_tensor_with_ipc_none_value():
    """Check that a ``None`` tensor field survives encode/decode with IPC on."""
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    # A single queue shared by the sending and the receiving side.
    ipc_queue = torch_mp.Queue()

    # Encoder wired to the IPC sender.
    ipc_sender = TensorIpcSender(ipc_queue)
    encoder = MsgpackEncoder(oob_tensor_consumer=ipc_sender)

    # Decoder wired to the IPC receiver.
    ipc_receiver = TensorIpcReceiver(ipc_queue)
    decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=ipc_receiver)

    # The tensor field is deliberately left empty.
    outgoing = RequestWithTensor(prompt_embeds=None, data="test_data_with_none")

    # Full round trip.
    incoming = decoder.decode(encoder.encode(outgoing))

    # The decoded request must match what was sent.
    assert isinstance(incoming, RequestWithTensor)
    assert incoming.data == "test_data_with_none"
    assert incoming.prompt_embeds is None
def test_multiple_senders_single_receiver_ipc():
    """Exercise several senders feeding one receiver through a shared queue.

    Per the scenario this test models, multiple frontends each own a
    MsgpackEncoder + TensorIpcSender pair and all push tensors onto the same
    torch.multiprocessing queue, while one MsgpackDecoder + TensorIpcReceiver
    pair decodes everything.
    """
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    num_senders = 3
    num_messages_per_sender = 2

    tensor_queue = torch_mp.Queue()

    # Independent senders on the same queue, one encoder per sender.
    senders = [TensorIpcSender(tensor_queue) for _ in range(num_senders)]
    encoders = [MsgpackEncoder(oob_tensor_consumer=s) for s in senders]

    # The one and only receiving side.
    receiver = TensorIpcReceiver(tensor_queue)
    decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)

    # Encode round-robin across senders so that tensors belonging to
    # different senders end up interleaved on the queue.
    encoded_payloads: list[tuple[int, int, torch.Tensor, list]] = []
    for msg_idx in range(num_messages_per_sender):
        for sender_idx, encoder in enumerate(encoders):
            # Shape and fill value are unique per (sender, message) pair so
            # a mismatched tensor is detectable.
            fill_value = float(sender_idx * 100 + msg_idx)
            tensor = torch.full(
                (sender_idx + 1, msg_idx + 2),
                fill_value,
                dtype=torch.float32,
            )
            req = RequestWithTensor(
                prompt_embeds=tensor,
                data=f"s{sender_idx}_m{msg_idx}",
            )
            encoded_payloads.append(
                (sender_idx, msg_idx, tensor, encoder.encode(req))
            )

    # Decode in the same order; the receiver has to pair every handle with
    # the matching entry coming off the shared queue.
    for sender_idx, msg_idx, original_tensor, encoded in encoded_payloads:
        decoded = decoder.decode(encoded)
        assert isinstance(decoded, RequestWithTensor)
        assert decoded.data == f"s{sender_idx}_m{msg_idx}"

        got = decoded.prompt_embeds
        assert got is not None
        assert got.shape == original_tensor.shape, (
            f"Shape mismatch for sender {sender_idx} msg {msg_idx}: "
            f"{decoded.prompt_embeds.shape} != {original_tensor.shape}"
        )
        assert torch.allclose(got, original_tensor), (
            f"Value mismatch for sender {sender_idx} msg {msg_idx}"
        )
This diff is collapsed.
...@@ -14,7 +14,12 @@ import vllm.envs as envs ...@@ -14,7 +14,12 @@ import vllm.envs as envs
from vllm.config.model_arch import ( from vllm.config.model_arch import (
ModelArchitectureConfig, ModelArchitectureConfig,
) )
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.multimodal import (
MMCacheType,
MMEncoderTPMode,
MMTensorIPC,
MultiModalConfig,
)
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter from vllm.config.utils import config, getattr_iter
...@@ -310,6 +315,7 @@ class ModelConfig: ...@@ -310,6 +315,7 @@ class ModelConfig:
interleave_mm_strings: InitVar[bool | None] = None interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None video_pruning_rate: InitVar[float | None] = None
mm_tensor_ipc: InitVar[MMTensorIPC] = None
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
...@@ -430,6 +436,7 @@ class ModelConfig: ...@@ -430,6 +436,7 @@ class ModelConfig:
interleave_mm_strings: bool | None, interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None, skip_mm_profiling: bool | None,
video_pruning_rate: float | None, video_pruning_rate: float | None,
mm_tensor_ipc: MMTensorIPC,
) -> None: ) -> None:
# Keep set served_model_name before maybe_model_redirect(self.model) # Keep set served_model_name before maybe_model_redirect(self.model)
self.served_model_name = get_served_model_name( self.served_model_name = get_served_model_name(
...@@ -612,6 +619,7 @@ class ModelConfig: ...@@ -612,6 +619,7 @@ class ModelConfig:
interleave_mm_strings=interleave_mm_strings, interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling, skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate, video_pruning_rate=video_pruning_rate,
mm_tensor_ipc=mm_tensor_ipc,
) )
mm_config_kwargs = { mm_config_kwargs = {
...@@ -1112,6 +1120,22 @@ class ModelConfig: ...@@ -1112,6 +1120,22 @@ class ModelConfig:
f"({parallel_config.decode_context_parallel_size})." f"({parallel_config.decode_context_parallel_size})."
) )
# torch_shm uses a single IPC queue to rank 0; DP>1 is
# incompatible because API servers can't know which
# CoreEngine the scheduler will assign work to. TP>1 is
# also not supported because this requires broadcasting
# MM tensors between all TP ranks.
if (
self.multimodal_config is not None
and self.multimodal_config.mm_tensor_ipc == "torch_shm"
and parallel_config.world_size_across_dp > 1
):
raise ValueError(
"mm_tensor_ipc='torch_shm' is not supported with "
"data_parallel_size > 1 or tensor_parallel_size > 1 "
"or pipeline_parallel_size > 1."
)
def get_sliding_window(self) -> int | None: def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present.""" """Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None) return getattr(self.hf_text_config, "sliding_window", None)
......
...@@ -59,6 +59,7 @@ class MultiModalDummyOptionsBuiltins(TypedDict, total=False): ...@@ -59,6 +59,7 @@ class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
MMEncoderTPMode = Literal["weights", "data"] MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"] MMCacheType = Literal["shm", "lru"]
MMTensorIPC = Literal["direct_rpc", "torch_shm"]
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions] MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
""" """
A dictionary containing an entry for each modality type of dummy data. A dictionary containing an entry for each modality type of dummy data.
...@@ -172,6 +173,11 @@ class MultiModalConfig: ...@@ -172,6 +173,11 @@ class MultiModalConfig:
Value sits in range [0;1) and determines fraction of media tokens Value sits in range [0;1) and determines fraction of media tokens
from each video to be pruned. from each video to be pruned.
""" """
mm_tensor_ipc: MMTensorIPC = "direct_rpc"
"""IPC (inter-process communication) method for multimodal tensors.
- "direct_rpc": Use msgspec serialization via RPC
- "torch_shm": Use torch.multiprocessing shared memory for zero-copy IPC
Defaults to "direct_rpc". """
@field_validator("limit_per_prompt", mode="before") @field_validator("limit_per_prompt", mode="before")
@classmethod @classmethod
......
...@@ -766,6 +766,17 @@ class VllmConfig: ...@@ -766,6 +766,17 @@ class VllmConfig:
else: else:
self.parallel_config.disable_nccl_for_dp_synchronization = False self.parallel_config.disable_nccl_for_dp_synchronization = False
if (
self.model_config is not None
and self.model_config.multimodal_config is not None
and self.model_config.multimodal_config.mm_tensor_ipc == "torch_shm"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
raise ValueError(
"torch_shm is known to fail without "
"VLLM_WORKER_MULTIPROC_METHOD set to spawn"
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
if ( if (
......
...@@ -79,7 +79,7 @@ from vllm.config.model import ( ...@@ -79,7 +79,7 @@ from vllm.config.model import (
RunnerOption, RunnerOption,
TokenizerMode, TokenizerMode,
) )
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MMTensorIPC
from vllm.config.observability import DetailedTraceModules from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import ( from vllm.config.parallel import (
All2AllBackend, All2AllBackend,
...@@ -509,6 +509,7 @@ class EngineArgs: ...@@ -509,6 +509,7 @@ class EngineArgs:
io_processor_plugin: str | None = None io_processor_plugin: str | None = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
# LoRA fields # LoRA fields
enable_lora: bool = False enable_lora: bool = False
max_loras: int = LoRAConfig.max_loras max_loras: int = LoRAConfig.max_loras
...@@ -1097,6 +1098,9 @@ class EngineArgs: ...@@ -1097,6 +1098,9 @@ class EngineArgs:
multimodal_group.add_argument( multimodal_group.add_argument(
"--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
) )
multimodal_group.add_argument(
"--mm-tensor-ipc", **multimodal_kwargs["mm_tensor_ipc"]
)
# LoRA related configs # LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig) lora_kwargs = get_kwargs(LoRAConfig)
...@@ -1423,6 +1427,7 @@ class EngineArgs: ...@@ -1423,6 +1427,7 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype, override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors, logits_processors=self.logits_processors,
video_pruning_rate=self.video_pruning_rate, video_pruning_rate=self.video_pruning_rate,
mm_tensor_ipc=self.mm_tensor_ipc,
io_processor_plugin=self.io_processor_plugin, io_processor_plugin=self.io_processor_plugin,
) )
......
...@@ -290,7 +290,7 @@ def run_multi_api_server(args: argparse.Namespace): ...@@ -290,7 +290,7 @@ def run_multi_api_server(args: argparse.Namespace):
with launch_core_engines( with launch_core_engines(
vllm_config, executor_class, log_stats, addresses, num_api_servers vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses): ) as (local_engine_manager, coordinator, addresses, tensor_queue):
# Construct common args for the APIServerProcessManager up-front. # Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict( api_server_manager_kwargs = dict(
target_server_fn=run_api_server_worker_proc, target_server_fn=run_api_server_worker_proc,
...@@ -303,6 +303,7 @@ def run_multi_api_server(args: argparse.Namespace): ...@@ -303,6 +303,7 @@ def run_multi_api_server(args: argparse.Namespace):
stats_update_address=coordinator.get_stats_publish_address() stats_update_address=coordinator.get_stats_publish_address()
if coordinator if coordinator
else None, else None,
tensor_queue=tensor_queue,
) )
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
......
...@@ -13,6 +13,7 @@ from enum import IntEnum ...@@ -13,6 +13,7 @@ from enum import IntEnum
from functools import partial from functools import partial
from inspect import isclass, signature from inspect import isclass, signature
from logging import DEBUG from logging import DEBUG
from multiprocessing.queues import Queue
from typing import Any, TypeVar, cast from typing import Any, TypeVar, cast
import msgspec import msgspec
...@@ -59,6 +60,7 @@ from vllm.v1.engine import ( ...@@ -59,6 +60,7 @@ from vllm.v1.engine import (
UtilityOutput, UtilityOutput,
UtilityResult, UtilityResult,
) )
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver
from vllm.v1.engine.utils import ( from vllm.v1.engine.utils import (
EngineHandshakeMetadata, EngineHandshakeMetadata,
EngineZmqAddresses, EngineZmqAddresses,
...@@ -788,6 +790,7 @@ class EngineCoreProc(EngineCore): ...@@ -788,6 +790,7 @@ class EngineCoreProc(EngineCore):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_handshake_address: str | None = None, client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
*, *,
engine_index: int = 0, engine_index: int = 0,
): ):
...@@ -802,6 +805,12 @@ class EngineCoreProc(EngineCore): ...@@ -802,6 +805,12 @@ class EngineCoreProc(EngineCore):
self.engines_running = False self.engines_running = False
self.shutdown_state = EngineShutdownState.RUNNING self.shutdown_state = EngineShutdownState.RUNNING
# Receiver for tensor IPC
self.tensor_ipc_receiver: TensorIpcReceiver | None = None
if tensor_queue is not None:
self.tensor_ipc_receiver = TensorIpcReceiver(tensor_queue)
logger.info("Using tensor IPC queue for multimodal tensor sharing")
with self._perform_handshakes( with self._perform_handshakes(
handshake_address, handshake_address,
identity, identity,
...@@ -1340,9 +1349,11 @@ class EngineCoreProc(EngineCore): ...@@ -1340,9 +1349,11 @@ class EngineCoreProc(EngineCore):
): ):
"""Input socket IO thread.""" """Input socket IO thread."""
# Msgpack serialization decoding. # Msgpack serialization decoding with optional tensor IPC receiver.
add_request_decoder = MsgpackDecoder(EngineCoreRequest) add_request_decoder = MsgpackDecoder(
generic_decoder = MsgpackDecoder() EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
)
generic_decoder = MsgpackDecoder(oob_tensor_provider=self.tensor_ipc_receiver)
with ExitStack() as stack, zmq.Context() as ctx: with ExitStack() as stack, zmq.Context() as ctx:
input_sockets = [ input_sockets = [
...@@ -1418,10 +1429,7 @@ class EngineCoreProc(EngineCore): ...@@ -1418,10 +1429,7 @@ class EngineCoreProc(EngineCore):
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
def process_output_sockets( def process_output_sockets(
self, self, output_paths: list[str], coord_output_path: str | None, engine_index: int
output_paths: list[str],
coord_output_path: str | None,
engine_index: int,
): ):
"""Output socket IO thread.""" """Output socket IO thread."""
...@@ -1580,6 +1588,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1580,6 +1588,7 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_handshake_address: str | None = None, client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
): ):
assert vllm_config.model_config.is_moe, ( assert vllm_config.model_config.is_moe, (
"DPEngineCoreProc should only be used for MoE models" "DPEngineCoreProc should only be used for MoE models"
...@@ -1605,6 +1614,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1605,6 +1614,7 @@ class DPEngineCoreProc(EngineCoreProc):
log_stats, log_stats,
client_handshake_address, client_handshake_address,
engine_index=dp_rank, engine_index=dp_rank,
tensor_queue=tensor_queue,
) )
def _init_data_parallel(self, vllm_config: VllmConfig): def _init_data_parallel(self, vllm_config: VllmConfig):
......
...@@ -12,6 +12,7 @@ from collections import defaultdict, deque ...@@ -12,6 +12,7 @@ from collections import defaultdict, deque
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing.queues import Queue
from threading import Thread from threading import Thread
from typing import Any, TypeAlias, TypeVar from typing import Any, TypeAlias, TypeVar
...@@ -45,6 +46,7 @@ from vllm.v1.engine import ( ...@@ -45,6 +46,7 @@ from vllm.v1.engine import (
from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.tensor_ipc import TensorIpcSender
from vllm.v1.engine.utils import ( from vllm.v1.engine.utils import (
CoreEngineActorManager, CoreEngineActorManager,
CoreEngineProcManager, CoreEngineProcManager,
...@@ -477,9 +479,6 @@ class MPClient(EngineCoreClient): ...@@ -477,9 +479,6 @@ class MPClient(EngineCoreClient):
client_addresses: dict[str, str] | None = None, client_addresses: dict[str, str] | None = None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
# Serialization setup.
self.encoder = MsgpackEncoder()
self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup. # ZMQ setup.
sync_ctx = zmq.Context(io_threads=2) sync_ctx = zmq.Context(io_threads=2)
...@@ -501,11 +500,14 @@ class MPClient(EngineCoreClient): ...@@ -501,11 +500,14 @@ class MPClient(EngineCoreClient):
enable_input_socket_handover = parallel_config.enable_elastic_ep enable_input_socket_handover = parallel_config.enable_elastic_ep
self.stats_update_address: str | None = None self.stats_update_address: str | None = None
tensor_queue: Queue | None = None
if client_addresses: if client_addresses:
# Engines are managed externally to this client. # Engines are managed externally to this client.
input_address = client_addresses["input_address"] input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"] output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address") self.stats_update_address = client_addresses.get("stats_update_address")
# Tensor queues passed via client_addresses for multi-API-server case
tensor_queue = client_addresses.get("tensor_queue") # type: ignore[assignment]
self.input_socket = self.resources.input_socket = make_zmq_socket( self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, self.ctx,
input_address, input_address,
...@@ -532,7 +534,7 @@ class MPClient(EngineCoreClient): ...@@ -532,7 +534,7 @@ class MPClient(EngineCoreClient):
with launch_core_engines( with launch_core_engines(
vllm_config, executor_class, log_stats, addresses vllm_config, executor_class, log_stats, addresses
) as (engine_manager, coordinator, addresses): ) as (engine_manager, coordinator, addresses, tensor_queue):
self.resources.coordinator = coordinator self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager self.resources.engine_manager = engine_manager
...@@ -542,6 +544,17 @@ class MPClient(EngineCoreClient): ...@@ -542,6 +544,17 @@ class MPClient(EngineCoreClient):
coordinator.get_stats_publish_address() coordinator.get_stats_publish_address()
) )
# Serialization setup with tensor queues for multimodal tensor IPC.
tensor_ipc_sender: TensorIpcSender | None = None
model_config = getattr(vllm_config, "model_config", None)
if model_config is not None and model_config.multimodal_config is not None:
mm_tensor_ipc = model_config.multimodal_config.mm_tensor_ipc
if mm_tensor_ipc == "torch_shm" and tensor_queue is not None:
tensor_ipc_sender = TensorIpcSender(tensor_queue)
self.encoder = MsgpackEncoder(oob_tensor_consumer=tensor_ipc_sender)
self.decoder = MsgpackDecoder(EngineCoreOutputs)
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index dp_rank = parallel_config.data_parallel_index
dp_local_size = parallel_config.data_parallel_size_local dp_local_size = parallel_config.data_parallel_size_local
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tensor IPC transport via torch.multiprocessing.Queue.
This module contains the queue-based transport logic for sharing tensors
between processes (e.g., API server -> engine core). The msgpack layer
emits/consumes lightweight :class:`TensorIpcData` values, while transport
state such as request association, handle generation, queue routing, buffering,
and cleanup lives here.
"""
import dataclasses
import uuid
from collections import defaultdict
from dataclasses import field
from multiprocessing.queues import Queue as MPQueue
from typing import Any
import torch
from vllm.logger import init_logger
from vllm.v1.serial_utils import OOBTensorConsumer
logger = init_logger(__name__)
TensorIpcQueue = MPQueue
@dataclasses.dataclass
class TensorIpcData:
    """Item placed on the tensor IPC queue.

    Bundles a tensor with the three identifiers the receiving side uses to
    match it to a handle embedded in the serialized message.
    """

    # Identifier of the sender instance that produced this item.
    sender_id: str
    # Per-sender message counter value at the time the tensor was sent.
    message_id: int
    # Index of this tensor within its message.
    tensor_id: int
    # The tensor itself (the sender moves it to shared memory before queuing).
    tensor: torch.Tensor
class TensorIpcSender(OOBTensorConsumer):
    """Sending half of the queue-based tensor IPC transport.

    Each instance owns a short random sender id plus two counters (message
    and tensor-within-message). Every tensor offered via ``__call__`` is put
    on the single queue together with those identifiers, and the same
    identifiers are returned as the handle to embed in the serialized
    message.

    NOTE(review): the previous docstring described the queue as targeting
    rank 0 and stated that DP>1 is not supported; that topology is not
    visible from this class itself - confirm against the launch code.
    """

    def __init__(self, queue: TensorIpcQueue):
        self.queue = queue
        # Index of the next tensor within the current message.
        self._tensor_id_counter = 0
        # Incremented once per message by new_message().
        self._message_counter = 0
        # Short random id so a receiver can tell senders on one queue apart.
        self._sender_id = uuid.uuid4().hex[:8]

    def set_target_engine(self, target_engine: int) -> None:
        """Accept only engine 0; raise IndexError for any other target."""
        if target_engine != 0:
            message = (
                "TensorIpcSender only supports a single queue; "
                f"got target engine {target_engine}"
            )
            raise IndexError(message)

    def new_message(self) -> None:
        """Advance to the next message and restart tensor numbering."""
        self._tensor_id_counter = 0
        self._message_counter += 1

    def __call__(self, tensor: torch.Tensor) -> dict[str, Any] | None:
        """Queue ``tensor`` and return its handle, or None on any failure.

        A None return tells the caller to fall back to regular
        serialization for this tensor.
        """
        try:
            # The tensor must live in shared memory before it is queued.
            shared = tensor if tensor.is_shared() else tensor.share_memory_()

            handle = {
                "sender_id": self._sender_id,
                "message_id": self._message_counter,
                "tensor_id": self._tensor_id_counter,
            }
            self._tensor_id_counter += 1

            item = TensorIpcData(tensor=shared, **handle)  # type: ignore[arg-type]

            # Bounded wait so a stuck queue cannot block forever.
            self.queue.put(item, timeout=10.0)

            logger.debug(
                "Sent tensor %s for (shape=%s, device=%s) "
                "via IPC queue (shared memory)",
                handle,
                shared.shape,
                shared.device,
            )
        except Exception as e:
            # Best effort: any failure degrades to the non-IPC path.
            logger.warning(
                "Failed to send tensor via IPC queue: %s. "
                "Falling back to standard serialization.",
                e,
            )
            return None
        return handle
@dataclasses.dataclass
class _Sender:
    """Receiver-side bookkeeping for one sender id."""

    # Most recent message id a tensor was returned for; -1 means none yet.
    current_message_id: int = -1
    # Buffered tensors, keyed by message id and then by tensor id.
    tensors: dict[int, dict[int, torch.Tensor]] = field(default_factory=dict)
class TensorIpcReceiver:
    """Receiving half of the queue-based tensor IPC transport.

    Pulls ``TensorIpcData`` items off the queue and buffers them per sender
    id, so a tensor can be looked up by the handle (sender id, message id,
    tensor id) that travelled inside the serialized message.
    """

    def __init__(self, queue: TensorIpcQueue):
        self.queue = queue
        # sender_id -> buffered tensors and last-seen message id for it.
        self._tensor_buffers = defaultdict[str, _Sender](_Sender)

    def __call__(
        self, dtype: str, shape: tuple[int, ...], meta: dict[str, Any]
    ) -> torch.Tensor:
        """Return the tensor identified by ``meta``, reading the queue as needed.

        The buffer is checked first; while the requested tensor is absent,
        one item at a time is taken from the queue and buffered, whichever
        sender or message it belongs to, because items from several
        producers can be interleaved.

        Args:
            dtype: Part of the provider call signature; not used here.
            shape: Part of the provider call signature; not used here.
            meta: Handle with ``"sender_id"``, ``"message_id"`` and
                ``"tensor_id"`` keys.

        Returns:
            The buffered tensor matching the handle (removed from the buffer).

        Raises:
            KeyError: If ``meta`` lacks one of the three keys.
            queue.Empty: If the tensor is not buffered and nothing arrives
                on the queue within 10 seconds.
        """
        sender_id: str = meta["sender_id"]
        message_id: int = meta["message_id"]
        tensor_id: int = meta["tensor_id"]

        while True:
            # .get() rather than indexing so a lookup never creates an entry.
            sender = self._tensor_buffers.get(sender_id)
            if sender is not None:
                tensors = sender.tensors
                tensor = tensors.get(message_id, {}).pop(tensor_id, None)
                if tensor is not None:
                    if sender.current_message_id != message_id:
                        # First tensor of a newer message: drop everything
                        # still buffered for older messages of this sender.
                        while tensors and (mid := next(iter(tensors))) < message_id:
                            stale = tensors.pop(mid)
                            if stale:
                                # Both placeholders need a value: the count
                                # of dropped tensors and the sender id.
                                logger.warning(
                                    "Discarding %d stale tensors from sender %s",
                                    len(stale),
                                    sender_id,
                                )
                        sender.current_message_id = message_id
                    logger.debug(
                        "Received tensor %s from sender %s for (shape=%s, device=%s) "
                        "via IPC queue (shared memory)",
                        (message_id, tensor_id),
                        sender_id,
                        tensor.shape,
                        tensor.device,
                    )
                    return tensor

            # Not buffered yet: block (bounded) for the next queued item.
            ipc_data: TensorIpcData = self.queue.get(timeout=10.0)

            # Indexing the defaultdict creates the per-sender record on
            # first contact.
            sender = self._tensor_buffers[ipc_data.sender_id]
            if sender.current_message_id > ipc_data.message_id:
                # Belongs to a message older than one already being consumed.
                logger.warning(
                    "Ignoring stale tensor from sender %s", ipc_data.sender_id
                )
                continue
            sender.tensors.setdefault(ipc_data.message_id, {})[ipc_data.tensor_id] = (
                ipc_data.tensor
            )
...@@ -10,6 +10,7 @@ from dataclasses import dataclass ...@@ -10,6 +10,7 @@ from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from multiprocessing import Process, connection from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from unittest.mock import patch from unittest.mock import patch
...@@ -95,6 +96,7 @@ class CoreEngineProcManager: ...@@ -95,6 +96,7 @@ class CoreEngineProcManager:
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_handshake_address: str | None = None, client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
): ):
context = get_mp_context() context = get_mp_context()
common_kwargs = { common_kwargs = {
...@@ -103,6 +105,7 @@ class CoreEngineProcManager: ...@@ -103,6 +105,7 @@ class CoreEngineProcManager:
"handshake_address": handshake_address, "handshake_address": handshake_address,
"executor_class": executor_class, "executor_class": executor_class,
"log_stats": log_stats, "log_stats": log_stats,
"tensor_queue": tensor_queue,
} }
if client_handshake_address: if client_handshake_address:
...@@ -864,6 +867,7 @@ def launch_core_engines( ...@@ -864,6 +867,7 @@ def launch_core_engines(
CoreEngineProcManager | CoreEngineActorManager | None, CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None, DPCoordinator | None,
EngineZmqAddresses, EngineZmqAddresses,
Queue | None,
] ]
]: ]:
"""Launch engine and DP coordinator processes as needed.""" """Launch engine and DP coordinator processes as needed."""
...@@ -878,6 +882,14 @@ def launch_core_engines( ...@@ -878,6 +882,14 @@ def launch_core_engines(
offline_mode = local_start_index is not None offline_mode = local_start_index is not None
# Create a single tensor IPC queue for sharing multimodal tensors between
# API servers and engine core. Returns a single queue since we only support
# DP=1 for this data flow.
tensor_queue: Queue | None = None
multimodal_config = vllm_config.model_config.multimodal_config
if multimodal_config is not None and multimodal_config.mm_tensor_ipc == "torch_shm":
tensor_queue = get_mp_context().Queue()
# Run the DP Coordinator process with rank 0 when in online DP mode. # Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for: # The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
...@@ -913,7 +925,7 @@ def launch_core_engines( ...@@ -913,7 +925,7 @@ def launch_core_engines(
log_stats=log_stats, log_stats=log_stats,
) )
yield engine_actor_manager, coordinator, addresses yield engine_actor_manager, coordinator, addresses, tensor_queue
return return
if offline_mode: if offline_mode:
...@@ -975,11 +987,12 @@ def launch_core_engines( ...@@ -975,11 +987,12 @@ def launch_core_engines(
local_engine_count=local_engine_count, local_engine_count=local_engine_count,
start_index=dp_rank, start_index=dp_rank,
local_start_index=local_start_index or 0, local_start_index=local_start_index or 0,
tensor_queue=tensor_queue,
) )
else: else:
local_engine_manager = None local_engine_manager = None
yield local_engine_manager, coordinator, addresses yield local_engine_manager, coordinator, addresses, tensor_queue
# Now wait for engines to start. # Now wait for engines to start.
wait_for_engine_startup( wait_for_engine_startup(
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import dataclasses import dataclasses
import importlib import importlib
import pickle import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from functools import partial from functools import partial
from inspect import isclass from inspect import isclass
...@@ -53,6 +54,27 @@ MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = { ...@@ -53,6 +54,27 @@ MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
class OOBTensorConsumer(ABC):
@abstractmethod
def __call__(self, tensor: torch.Tensor) -> dict | None:
"""
Called with tensors for the current message.
Returns None to reject the tensor (falls back to regular serialization),
otherwise a dict with arbitrary placeholder data to be included
in the serialized message.
"""
return None
@abstractmethod
def new_message(self) -> None:
"""Called at the start of each new encoded message."""
pass
# Decoder-side callback that resolves an out-of-band tensor placeholder.
# Signature: (dtype, shape, metadata) -> tensor
OOBTensorProvider = Callable[[str, tuple[int, ...], dict], torch.Tensor]
def _log_insecure_serialization_warning(): def _log_insecure_serialization_warning():
logger.warning_once( logger.warning_once(
"Allowing insecure serialization using pickle due to " "Allowing insecure serialization using pickle due to "
...@@ -119,9 +141,16 @@ class MsgpackEncoder: ...@@ -119,9 +141,16 @@ class MsgpackEncoder:
By default, arrays below 256B are serialized inline Larger will get sent By default, arrays below 256B are serialized inline Larger will get sent
via dedicated messages. Note that this is a per-tensor limit. via dedicated messages. Note that this is a per-tensor limit.
When an ``oob_tensor_consumer`` is provided, tensors (CUDA and CPU) will be
offered to it for out-of-band handling.
""" """
def __init__(self, size_threshold: int | None = None): def __init__(
self,
size_threshold: int | None = None,
oob_tensor_consumer: OOBTensorConsumer | None = None,
):
if size_threshold is None: if size_threshold is None:
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
...@@ -130,11 +159,14 @@ class MsgpackEncoder: ...@@ -130,11 +159,14 @@ class MsgpackEncoder:
# pass custom data to the hook otherwise. # pass custom data to the hook otherwise.
self.aux_buffers: list[bytestr] | None = None self.aux_buffers: list[bytestr] | None = None
self.size_threshold = size_threshold self.size_threshold = size_threshold
self.oob_tensor_consumer = oob_tensor_consumer
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
_log_insecure_serialization_warning() _log_insecure_serialization_warning()
def encode(self, obj: Any) -> Sequence[bytestr]: def encode(self, obj: Any) -> Sequence[bytestr]:
try: try:
if self.oob_tensor_consumer is not None:
self.oob_tensor_consumer.new_message()
self.aux_buffers = bufs = [b""] self.aux_buffers = bufs = [b""]
bufs[0] = self.encoder.encode(obj) bufs[0] = self.encoder.encode(obj)
# This `bufs` list allows us to collect direct pointers to backing # This `bufs` list allows us to collect direct pointers to backing
...@@ -147,6 +179,8 @@ class MsgpackEncoder: ...@@ -147,6 +179,8 @@ class MsgpackEncoder:
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
try: try:
if self.oob_tensor_consumer is not None:
self.oob_tensor_consumer.new_message()
self.aux_buffers = [buf] self.aux_buffers = [buf]
bufs = self.aux_buffers bufs = self.aux_buffers
self.encoder.encode_into(obj, buf) self.encoder.encode_into(obj, buf)
...@@ -222,17 +256,19 @@ class MsgpackEncoder: ...@@ -222,17 +256,19 @@ class MsgpackEncoder:
def _encode_tensor(
    self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], int | dict | memoryview]:
    """Encode a tensor as a ``(dtype, shape, data)`` triple.

    ``data`` takes one of three forms, tried in this order:

    * an inline raw-view extension, for CPU tensors smaller than
      ``self.size_threshold`` bytes;
    * a placeholder dict, when an out-of-band consumer is configured
      and accepts the tensor;
    * an integer index into ``self.aux_buffers``, where the tensor's
      bytes are appended so they are not copied into the main message.

    Args:
        obj: The tensor to encode.

    Returns:
        The dtype name without the ``torch.`` prefix, the tensor shape,
        and the data element described above.
    """
    dtype_name = str(obj.dtype).removeprefix("torch.")

    # Small CPU tensors are encoded inline, just like ndarrays.
    if obj.nbytes < self.size_threshold and obj.is_cpu:
        inline = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj))
        return dtype_name, obj.shape, inline

    # Offer the tensor to the out-of-band consumer, if one is configured;
    # a None reply means it declined and we fall through.
    consumer = self.oob_tensor_consumer
    if consumer is not None:
        placeholder = consumer(obj)
        if placeholder is not None:
            assert isinstance(placeholder, dict)
            return dtype_name, obj.shape, placeholder

    # Otherwise record the index of a backing buffer to avoid a copy.
    assert self.aux_buffers is not None
    buffer_index = len(self.aux_buffers)
    self.aux_buffers.append(tensor_data(obj))
    return dtype_name, obj.shape, buffer_index
...@@ -279,9 +315,17 @@ class MsgpackDecoder: ...@@ -279,9 +315,17 @@ class MsgpackDecoder:
Note that unlike vanilla `msgspec` Decoders, this interface is generally Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays. not thread-safe when encoding tensors / numpy arrays.
``oob_tensor_provider`` must be used when an OOBTensorConsumer is used on the
encoder side.
""" """
def __init__(self, t: Any | None = None, share_mem: bool = True): def __init__(
self,
t: Any | None = None,
share_mem: bool = True,
oob_tensor_provider: OOBTensorProvider | None = None,
):
self.share_mem = share_mem self.share_mem = share_mem
self.pin_tensors = is_pin_memory_available() self.pin_tensors = is_pin_memory_available()
args = () if t is None else (t,) args = () if t is None else (t,)
...@@ -289,6 +333,7 @@ class MsgpackDecoder: ...@@ -289,6 +333,7 @@ class MsgpackDecoder:
*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
) )
self.aux_buffers: Sequence[bytestr] = () self.aux_buffers: Sequence[bytestr] = ()
self.oob_tensor_provider = oob_tensor_provider
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
_log_insecure_serialization_warning() _log_insecure_serialization_warning()
...@@ -353,6 +398,12 @@ class MsgpackDecoder: ...@@ -353,6 +398,12 @@ class MsgpackDecoder:
def _decode_tensor(self, arr: Any) -> torch.Tensor: def _decode_tensor(self, arr: Any) -> torch.Tensor:
dtype, shape, data = arr dtype, shape, data = arr
if isinstance(data, dict):
assert self.oob_tensor_provider, (
"Received OOB tensor but tensor provider is not set"
)
return self.oob_tensor_provider(dtype, shape, data)
is_aux = isinstance(data, int) is_aux = isinstance(data, int)
buffer = self.aux_buffers[data] if is_aux else data buffer = self.aux_buffers[data] if is_aux else data
buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer) buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
......
...@@ -10,6 +10,7 @@ from contextlib import AbstractContextManager ...@@ -10,6 +10,7 @@ from contextlib import AbstractContextManager
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import connection from multiprocessing import connection
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
...@@ -173,6 +174,7 @@ class APIServerProcessManager: ...@@ -173,6 +174,7 @@ class APIServerProcessManager:
input_addresses: list[str], input_addresses: list[str],
output_addresses: list[str], output_addresses: list[str],
stats_update_address: str | None = None, stats_update_address: str | None = None,
tensor_queue: Queue | None = None,
): ):
"""Initialize and start API server worker processes. """Initialize and start API server worker processes.
...@@ -185,6 +187,7 @@ class APIServerProcessManager: ...@@ -185,6 +187,7 @@ class APIServerProcessManager:
input_addresses: Input addresses for each API server input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address stats_update_address: Optional stats update address
tensor_queue: Optional tensor IPC queue for sharing MM tensors
""" """
self.listen_address = listen_address self.listen_address = listen_address
self.sock = sock self.sock = sock
...@@ -205,6 +208,8 @@ class APIServerProcessManager: ...@@ -205,6 +208,8 @@ class APIServerProcessManager:
} }
if stats_update_address is not None: if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address client_config["stats_update_address"] = stats_update_address
if tensor_queue is not None:
client_config["tensor_queue"] = tensor_queue
proc = spawn_context.Process( proc = spawn_context.Process(
target=target_server_fn, target=target_server_fn,
...@@ -419,7 +424,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview: ...@@ -419,7 +424,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview:
Returns: Returns:
A memoryview of the tensor data as uint8. A memoryview of the tensor data as uint8.
""" """
return tensor.flatten().contiguous().view(torch.uint8).numpy().data return tensor.flatten().cpu().contiguous().view(torch.uint8).numpy().data
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment