"vscode:/vscode.git/clone" did not exist on "f46720c996064a997524a9ba419950492a72aaab"
Unverified Commit 20ccc9b2 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

refactor: add prefill_worker_utils in vLLM (#6017)

parent 1aab7f6b
...@@ -3,10 +3,9 @@ ...@@ -3,10 +3,9 @@
import copy import copy
import logging import logging
import os
from collections import defaultdict from collections import defaultdict
from typing import Any
import safetensors
import torch import torch
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
...@@ -15,18 +14,15 @@ import dynamo.nixl_connect as connect ...@@ -15,18 +14,15 @@ import dynamo.nixl_connect as connect
from dynamo.runtime import Client, Component, DistributedRuntime from dynamo.runtime import Client, Component, DistributedRuntime
from ..handlers import BaseWorkerHandler from ..handlers import BaseWorkerHandler
from ..multimodal_utils import ( from ..multimodal_utils import ImageLoader, MyRequestOutput, vLLMMultimodalRequest
ImageLoader,
MyRequestOutput,
construct_mm_data,
vLLMMultimodalRequest,
)
from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import (
accumulate_embeddings,
load_embeddings,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalDecodeWorkerHandler(BaseWorkerHandler): class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
"""Decode worker for disaggregated multimodal serving""" """Decode worker for disaggregated multimodal serving"""
...@@ -181,108 +177,29 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -181,108 +177,29 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request = vLLMMultimodalRequest.model_validate(request) request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
multi_modal_data = defaultdict(list) multi_modal_data: dict[str, Any] = defaultdict(list)
for mi in request.multimodal_inputs: for mi in request.multimodal_inputs:
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
if self.config.ec_consumer_mode:
logger.debug(
f"[{request.request_id}] ECConnector consumer mode: "
f"vLLM will load embeddings from cache using mm_hash"
)
# Use PIL image loading - vLLM will detect it's already in EC cache
# and load from disk instead of reprocessing
if mi.multimodal_input.image_url: if mi.multimodal_input.image_url:
# PIL image path — used by both EC consumer mode
# (vLLM looks up cached embeddings via mm_hash) and
# non-disaggregated mode (vLLM encodes inline).
multi_modal_data["image"].append( multi_modal_data["image"].append(
await self.image_loader.load_image( await self.image_loader.load_image(mi.multimodal_input.image_url)
mi.multimodal_input.image_url
)
)
elif mi.multimodal_input.video_url:
# For video, load as image placeholder (vLLM will use EC cache)
multi_modal_data["image"].append(
await self.image_loader.load_image(
request.multimodal_input.video_url
)
)
else:
raise ValueError(
"ECConnector mode requires multimodal_input with image/video URL"
) )
elif (
mi.multimodal_input.image_url is None
and mi.multimodal_input.video_url is None
):
# Process embeddings using the connector
# Create a descriptor based on the embedding shape.
if TRANSFER_LOCAL:
logger.info("PD: Loading local safetensors file")
embeddings = safetensors.torch.load_file(mi.serialized_request)[
"ec_cache"
]
else: else:
embeddings = torch.empty( # Pre-computed embeddings via NIXL RDMA or local safetensors
mi.embeddings_shape, embeddings = await load_embeddings(
dtype=self.EMBEDDINGS_DTYPE, mi,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op = await self._connector.begin_read(
mi.serialized_request, descriptor
)
await read_op.wait_for_completion()
if "video" in self.config.model.lower():
video_numpy = embeddings.numpy()
mm_data = construct_mm_data(
self.config.model,
self.EMBEDDINGS_DTYPE, self.EMBEDDINGS_DTYPE,
video_numpy=video_numpy, self.EMBEDDINGS_DEVICE,
self._connector,
) )
multi_modal_data["video"].append(mm_data["video"]) accumulate_embeddings(
else: multi_modal_data,
mm_data = construct_mm_data(
self.config.model, self.config.model,
self.EMBEDDINGS_DTYPE, self.EMBEDDINGS_DTYPE,
image_embeds=embeddings, embeddings,
image_grid_thw=mi.image_grid_thw, mi.image_grid_thw,
)
if isinstance(mm_data["image"], dict):
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
# Merging tensors
multi_modal_data["image"]["image_embeds"] = torch.cat(
(
multi_modal_data["image"]["image_embeds"],
mm_data["image"]["image_embeds"],
)
)
multi_modal_data["image"]["image_grid_thw"] = torch.cat(
(
multi_modal_data["image"]["image_grid_thw"],
mm_data["image"]["image_grid_thw"],
)
)
else:
logger.info(f"Get embedding of shape {mm_data['image'].shape}")
# [gluo FIXME] embedding with multiple images?
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"])
)
else:
# Use PIL image instead of image embeddings
multi_modal_data["image"].append(
await self.image_loader.load_image(mi.multimodal_input.image_url)
) )
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape # For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
......
...@@ -19,6 +19,10 @@ from dynamo.vllm.multimodal_utils.model import ( ...@@ -19,6 +19,10 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data, construct_mm_data,
load_vision_model, load_vision_model,
) )
from dynamo.vllm.multimodal_utils.prefill_worker_utils import (
accumulate_embeddings,
load_embeddings,
)
from dynamo.vllm.multimodal_utils.protocol import ( from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup, MultiModalGroup,
MultiModalInput, MultiModalInput,
...@@ -51,4 +55,6 @@ __all__ = [ ...@@ -51,4 +55,6 @@ __all__ = [
"vLLMMultimodalRequest", "vLLMMultimodalRequest",
"VLLMNativeEncoderRequest", "VLLMNativeEncoderRequest",
"VLLMNativeEncoderResponse", "VLLMNativeEncoderResponse",
"accumulate_embeddings",
"load_embeddings",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from typing import Any, Dict
import safetensors
import torch
import dynamo.nixl_connect as connect
from .model import construct_mm_data
from .protocol import MultiModalGroup
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Transport selector for pre-computed embeddings, read once at import time
# from the TRANSFER_LOCAL environment variable (defaults to 1 when unset).
# A non-zero value makes load_embeddings() read a local safetensors file;
# zero makes it read through the NIXL connector instead. A value that does
# not parse as an integer raises ValueError at import.
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
async def load_embeddings(
    mi: MultiModalGroup,
    embeddings_dtype: torch.dtype,
    embeddings_device: str,
    connector: connect.Connector | None,
) -> torch.Tensor:
    """Load a pre-computed embedding tensor via local safetensors or NIXL.

    The transport is selected by the module-level ``TRANSFER_LOCAL`` flag.

    Args:
        mi: A single MultiModalGroup. ``serialized_request`` is passed to
            ``safetensors.torch.load_file`` as the file name when
            ``TRANSFER_LOCAL`` is set, and to ``connector.begin_read``
            otherwise. ``embeddings_shape`` sizes the receive buffer on the
            connector path.
        embeddings_dtype: Torch dtype of the receive buffer (connector path
            only).
        embeddings_device: Device of the receive buffer (connector path only).
        connector: NIXL connector used for the read; must not be ``None``
            when ``TRANSFER_LOCAL`` is 0.

    Returns:
        The tensor stored under the ``"ec_cache"`` key of the safetensors
        file (local path), or the buffer allocated on ``embeddings_device``
        once the read operation has completed (connector path).

    Raises:
        RuntimeError: If the connector path is selected but ``connector`` is
            ``None``, or if no descriptor could be created for the buffer.
    """
    if TRANSFER_LOCAL:
        # ``import safetensors`` does not load the ``torch`` submodule by
        # itself, so import the loader explicitly instead of relying on some
        # other module having imported ``safetensors.torch`` first.
        from safetensors.torch import load_file

        logger.info("PD: Loading local safetensors file")
        return load_file(mi.serialized_request)["ec_cache"]

    if connector is None:
        # Fail with an actionable message instead of an AttributeError on
        # ``None.begin_read`` further down.
        raise RuntimeError(
            "NIXL connector is required to load embeddings when TRANSFER_LOCAL=0"
        )

    # Destination buffer the connector read fills in place.
    embeddings = torch.empty(
        mi.embeddings_shape,
        dtype=embeddings_dtype,
        device=embeddings_device,
    )
    descriptor = connect.Descriptor(embeddings)
    if descriptor is None:
        # Defensive check kept from the original implementation.
        raise RuntimeError(
            "Descriptor is None in PD worker - cannot process embeddings"
        )

    read_op = await connector.begin_read(mi.serialized_request, descriptor)
    await read_op.wait_for_completion()
    return embeddings
def accumulate_embeddings(
    multi_modal_data: Dict[str, Any],
    model: str,
    embeddings_dtype: torch.dtype,
    embeddings: torch.Tensor,
    image_grid_thw,
) -> None:
    """Build model-specific mm_data from embeddings and merge it in place.

    ``multi_modal_data`` is mutated; nothing is returned. Both a plain
    ``dict`` and a ``defaultdict(list)`` are accepted.

    Args:
        multi_modal_data: Accumulator updated under the ``"video"`` or
            ``"image"`` key.
        model: Model name. If it contains ``"video"`` (case-insensitive) the
            embeddings are handled as video, otherwise as image.
        embeddings_dtype: Dtype forwarded to ``construct_mm_data``.
        embeddings: Embedding tensor for one multimodal item.
        image_grid_thw: Forwarded to ``construct_mm_data`` on the image path.

    Behaviour:
        * Video: ``embeddings.numpy()`` is passed to ``construct_mm_data``
          and the resulting ``"video"`` entry is appended to the list under
          ``multi_modal_data["video"]``.
        * Image, dict result (Qwen-VL style): the first result is stored
          as-is; later results have their ``image_embeds`` and
          ``image_grid_thw`` tensors concatenated onto the stored ones.
        * Image, tensor result: the first result is stored as-is; later
          results are concatenated onto the stored tensor.
    """
    if "video" in model.lower():
        mm_data = construct_mm_data(
            model,
            embeddings_dtype,
            video_numpy=embeddings.numpy(),
        )
        # setdefault keeps this working for plain dicts, not only for
        # defaultdict(list) accumulators.
        multi_modal_data.setdefault("video", []).append(mm_data["video"])
        return

    mm_data = construct_mm_data(
        model,
        embeddings_dtype,
        image_embeds=embeddings,
        image_grid_thw=image_grid_thw,
    )
    incoming = mm_data["image"]
    is_dict_style = isinstance(incoming, dict)
    if not is_dict_style:
        # Plain tensor embeddings
        logger.info("Get embedding of shape %s", incoming.shape)

    # "Nothing accumulated yet" means a missing key or the empty list a
    # defaultdict(list) hands out. Checking the type explicitly avoids
    # comparing a tensor or dict against [] on later calls.
    current = multi_modal_data.get("image")
    if current is None or (isinstance(current, list) and not current):
        multi_modal_data["image"] = incoming
        return

    if is_dict_style:
        # Qwen-VL style: dict with image_embeds + image_grid_thw tensors.
        # [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
        for key in ("image_embeds", "image_grid_thw"):
            current[key] = torch.cat((current[key], incoming[key]))
    else:
        # [gluo FIXME] embedding with multiple images?
        multi_modal_data["image"] = torch.cat((current, incoming))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment