Unverified commit 20ccc9b2, authored by Qi Wang, committed by GitHub
Browse files

refactor: add prefill_worker_utils in vLLM (#6017)

parent 1aab7f6b
......@@ -3,10 +3,9 @@
import copy
import logging
import os
from collections import defaultdict
from typing import Any
import safetensors
import torch
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
......@@ -15,18 +14,15 @@ import dynamo.nixl_connect as connect
from dynamo.runtime import Client, Component, DistributedRuntime
from ..handlers import BaseWorkerHandler
from ..multimodal_utils import (
ImageLoader,
MyRequestOutput,
construct_mm_data,
vLLMMultimodalRequest,
)
from ..multimodal_utils import ImageLoader, MyRequestOutput, vLLMMultimodalRequest
from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import (
accumulate_embeddings,
load_embeddings,
)
logger = logging.getLogger(__name__)
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
"""Decode worker for disaggregated multimodal serving"""
......@@ -181,108 +177,29 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
multi_modal_data = defaultdict(list)
multi_modal_data: dict[str, Any] = defaultdict(list)
for mi in request.multimodal_inputs:
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
if self.config.ec_consumer_mode:
logger.debug(
f"[{request.request_id}] ECConnector consumer mode: "
f"vLLM will load embeddings from cache using mm_hash"
)
# Use PIL image loading - vLLM will detect it's already in EC cache
# and load from disk instead of reprocessing
if mi.multimodal_input.image_url:
# PIL image path — used by both EC consumer mode
# (vLLM looks up cached embeddings via mm_hash) and
# non-disaggregated mode (vLLM encodes inline).
multi_modal_data["image"].append(
await self.image_loader.load_image(
mi.multimodal_input.image_url
)
)
elif mi.multimodal_input.video_url:
# For video, load as image placeholder (vLLM will use EC cache)
multi_modal_data["image"].append(
await self.image_loader.load_image(
request.multimodal_input.video_url
)
)
else:
raise ValueError(
"ECConnector mode requires multimodal_input with image/video URL"
await self.image_loader.load_image(mi.multimodal_input.image_url)
)
elif (
mi.multimodal_input.image_url is None
and mi.multimodal_input.video_url is None
):
# Process embeddings using the connector
# Create a descriptor based on the embedding shape.
if TRANSFER_LOCAL:
logger.info("PD: Loading local safetensors file")
embeddings = safetensors.torch.load_file(mi.serialized_request)[
"ec_cache"
]
else:
embeddings = torch.empty(
mi.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op = await self._connector.begin_read(
mi.serialized_request, descriptor
)
await read_op.wait_for_completion()
if "video" in self.config.model.lower():
video_numpy = embeddings.numpy()
mm_data = construct_mm_data(
self.config.model,
# Pre-computed embeddings via NIXL RDMA or local safetensors
embeddings = await load_embeddings(
mi,
self.EMBEDDINGS_DTYPE,
video_numpy=video_numpy,
self.EMBEDDINGS_DEVICE,
self._connector,
)
multi_modal_data["video"].append(mm_data["video"])
else:
mm_data = construct_mm_data(
accumulate_embeddings(
multi_modal_data,
self.config.model,
self.EMBEDDINGS_DTYPE,
image_embeds=embeddings,
image_grid_thw=mi.image_grid_thw,
)
if isinstance(mm_data["image"], dict):
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
# Merging tensors
multi_modal_data["image"]["image_embeds"] = torch.cat(
(
multi_modal_data["image"]["image_embeds"],
mm_data["image"]["image_embeds"],
)
)
multi_modal_data["image"]["image_grid_thw"] = torch.cat(
(
multi_modal_data["image"]["image_grid_thw"],
mm_data["image"]["image_grid_thw"],
)
)
else:
logger.info(f"Get embedding of shape {mm_data['image'].shape}")
# [gluo FIXME] embedding with multiple images?
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"])
)
else:
# Use PIL image instead of image embeddings
multi_modal_data["image"].append(
await self.image_loader.load_image(mi.multimodal_input.image_url)
embeddings,
mi.image_grid_thw,
)
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
......
......@@ -19,6 +19,10 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data,
load_vision_model,
)
from dynamo.vllm.multimodal_utils.prefill_worker_utils import (
accumulate_embeddings,
load_embeddings,
)
from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput,
......@@ -51,4 +55,6 @@ __all__ = [
"vLLMMultimodalRequest",
"VLLMNativeEncoderRequest",
"VLLMNativeEncoderResponse",
"accumulate_embeddings",
"load_embeddings",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from typing import Any, Dict
import safetensors
import torch
import dynamo.nixl_connect as connect
from .model import construct_mm_data
from .protocol import MultiModalGroup
logger = logging.getLogger(__name__)
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
async def load_embeddings(
    mi: MultiModalGroup,
    embeddings_dtype: torch.dtype,
    embeddings_device: str,
    connector: connect.Connector | None,
) -> torch.Tensor:
    """Load a pre-computed embedding tensor via local safetensors or NIXL RDMA.

    The transport is selected by the module-level ``TRANSFER_LOCAL`` flag.

    Args:
        mi: A single MultiModalGroup. Its ``serialized_request`` field is
            passed to ``load_file`` on the local path and to
            ``connector.begin_read`` on the RDMA path; ``embeddings_shape``
            sizes the receive buffer on the RDMA path.
        embeddings_dtype: Torch dtype of the receive buffer (RDMA path only).
        embeddings_device: Device of the receive buffer (RDMA path only).
        connector: NIXL connector used for the RDMA read. May be ``None``
            only when ``TRANSFER_LOCAL`` is truthy.

    Returns:
        The embedding tensor. On the local path this is the ``"ec_cache"``
        entry of the safetensors file; on the RDMA path it is the buffer
        allocated on ``embeddings_device`` once the read has completed.

    Raises:
        RuntimeError: If the RDMA path is selected but ``connector`` is
            ``None``, or if no descriptor could be created for the buffer.
    """
    if TRANSFER_LOCAL:
        # Import the submodule explicitly: a bare ``import safetensors`` is
        # not guaranteed to load ``safetensors.torch``, so attribute access
        # would only work if some other module had imported it first.
        from safetensors.torch import load_file

        logger.info("PD: Loading local safetensors file")
        return load_file(mi.serialized_request)["ec_cache"]

    # Fail fast with an actionable message instead of an AttributeError on
    # ``None.begin_read`` after the receive buffer has been allocated.
    if connector is None:
        raise RuntimeError(
            "NIXL connector is required when TRANSFER_LOCAL=0 - "
            "cannot read embeddings"
        )

    # Receive buffer that the RDMA read fills in place.
    embeddings = torch.empty(
        mi.embeddings_shape,
        dtype=embeddings_dtype,
        device=embeddings_device,
    )
    descriptor = connect.Descriptor(embeddings)
    if descriptor is None:
        raise RuntimeError(
            "Descriptor is None in PD worker - cannot process embeddings"
        )

    read_op = await connector.begin_read(mi.serialized_request, descriptor)
    await read_op.wait_for_completion()
    return embeddings
def accumulate_embeddings(
    multi_modal_data: Dict[str, Any],
    model: str,
    embeddings_dtype: torch.dtype,
    embeddings: torch.Tensor,
    image_grid_thw,
) -> None:
    """Build model-specific mm_data from ``embeddings`` and merge it in place.

    ``multi_modal_data`` is mutated; nothing is returned.

    * If ``"video"`` occurs in the lower-cased ``model`` name, the embeddings
      are converted to a numpy array, passed to ``construct_mm_data`` as
      ``video_numpy``, and the resulting ``"video"`` entry is appended to the
      list under ``multi_modal_data["video"]``.
    * Otherwise ``construct_mm_data`` is called with ``image_embeds`` and
      ``image_grid_thw``. The resulting ``"image"`` entry is either a dict
      holding ``image_embeds`` / ``image_grid_thw`` tensors or a plain
      tensor. The first entry is stored as-is; later entries are merged into
      it with ``torch.cat``.

    Args:
        multi_modal_data: Accumulator dict. A ``defaultdict(list)`` and a
            plain ``dict`` are both accepted; a missing key or an empty list
            means "nothing accumulated yet".
        model: Model name forwarded to ``construct_mm_data``.
        embeddings_dtype: Dtype forwarded to ``construct_mm_data``.
        embeddings: Embedding tensor for one multimodal item.
        image_grid_thw: Grid metadata forwarded to ``construct_mm_data`` on
            the image path.
    """
    if "video" in model.lower():
        # ``Tensor.numpy()`` only works for CPU tensors; ``cpu()`` returns the
        # tensor itself when it is already on the CPU.
        mm_data = construct_mm_data(
            model,
            embeddings_dtype,
            video_numpy=embeddings.cpu().numpy(),
        )
        # ``setdefault`` keeps plain dicts working, not just defaultdict(list).
        multi_modal_data.setdefault("video", []).append(mm_data["video"])
        return

    mm_data = construct_mm_data(
        model,
        embeddings_dtype,
        image_embeds=embeddings,
        image_grid_thw=image_grid_thw,
    )
    incoming = mm_data["image"]
    if not isinstance(incoming, dict):
        # Plain tensor embeddings
        logger.info(f"Get embedding of shape {incoming.shape}")

    existing = multi_modal_data.get("image")
    # Test for the empty placeholder by type rather than with ``== []``, which
    # would compare a tensor or dict against a list.
    if existing is None or (isinstance(existing, list) and not existing):
        multi_modal_data["image"] = incoming
        return

    if isinstance(incoming, dict):
        # Dict-style embeddings: concatenate each tensor field.
        # [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
        for key in ("image_embeds", "image_grid_thw"):
            existing[key] = torch.cat((existing[key], incoming[key]))
    else:
        # [gluo FIXME] embedding with multiple images?
        multi_modal_data["image"] = torch.cat((existing, incoming))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment