"vscode:/vscode.git/clone" did not exist on "0d3ff44000b70666300eaa37a0340ffbc25a0984"
Unverified Commit 334cbd9b authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

refactor: vllm EPD refactor (#4994)


Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 6f68be40
......@@ -34,7 +34,7 @@ from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler,
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
ProcessorHandler,
PreprocessedHandler,
VLLMEncodeWorkerHandler,
)
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
......@@ -676,13 +676,17 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
.client()
)
# Get prompt template from args (must be passed via environment or command line)
mm_prompt_template = config.mm_prompt_template
pd_worker_client = (
await runtime.namespace(config.namespace)
.component("backend")
.endpoint("generate")
.client()
)
handler = ProcessorHandler(
handler = PreprocessedHandler(
config.engine_args,
encode_worker_client,
mm_prompt_template,
pd_worker_client,
)
logger.info("Waiting for Encoder Worker Instances ...")
......@@ -690,7 +694,7 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
# Register the endpoint as entrypoint to a model
await register_llm(
ModelInput.Text, # Custom processor is used and this type bypasses SDK processor
ModelInput.Tokens,
ModelType.Chat,
generate_endpoint,
config.model,
......
......@@ -5,10 +5,8 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import (
EncodeWorkerHandler,
VLLMEncodeWorkerHandler,
)
from dynamo.vllm.multimodal_handlers.preprocessor_handler import (
ECProcessorHandler,
ProcessorHandler,
)
from dynamo.vllm.multimodal_handlers.preprocessed_handler import PreprocessedHandler
from dynamo.vllm.multimodal_handlers.preprocessor_handler import ECProcessorHandler
from dynamo.vllm.multimodal_handlers.worker_handler import (
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
......@@ -17,7 +15,7 @@ from dynamo.vllm.multimodal_handlers.worker_handler import (
__all__ = [
"EncodeWorkerHandler",
"VLLMEncodeWorkerHandler",
"ProcessorHandler",
"PreprocessedHandler",
"MultimodalPDWorkerHandler",
"MultimodalDecodeWorkerHandler",
"ECProcessorHandler",
......
......@@ -2,9 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import time
from typing import AsyncGenerator, AsyncIterator
import safetensors
from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TextPrompt
......@@ -16,10 +19,10 @@ from dynamo.runtime import Client, DistributedRuntime
from ..multimodal_utils import (
ImageLoader,
MyRequestOutput,
VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse,
encode_image_embeddings,
get_embedding_hash,
get_encoder_components,
load_vision_model,
vLLMMultimodalRequest,
......@@ -42,6 +45,8 @@ except ImportError as e:
CACHE_SIZE_MAXIMUM = 8
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class EncodeWorkerHandler:
def __init__(
......@@ -65,6 +70,10 @@ class EncodeWorkerHandler:
self.model, self.vision_model
)
self._connector = None
self._accumulated_time = 0.0
self._processed_requests = 0
self.readables = []
self.cached_embeddings = {}
def cleanup(self):
pass
......@@ -101,63 +110,107 @@ class EncodeWorkerHandler:
# 8. Yield the encode response.
try:
if not request.multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.")
image = await self.image_loader.load_image(
request.multimodal_input.image_url
)
logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
time_start = time.perf_counter()
for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.")
image_url = request.multimodal_inputs[idx].multimodal_input.image_url
# see if we have local cache
if image_url in self.cached_embeddings:
(
embedding_key,
image_grid_thw,
embeddings_shape,
) = self.cached_embeddings[image_url]
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[
idx
].serialized_request = (
f"/tmp/encoder_cache.{embedding_key}.safetensors"
)
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[idx].image_grid_thw = image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = embeddings_shape
continue
image = await self.image_loader.load_image(image_url)
logger.debug(
f"Processing image {image_url} for request: {{ id: {request_id} }}"
)
image_embeds = self.image_processor(images=image, return_tensors="pt")
# Encode the image embeddings using model-specific encoder
embeddings = encode_image_embeddings(
model_name=self.model,
image_embeds=image_embeds,
vision_encoder=self.vision_encoder,
projector=self.projector,
)
# Encode the image embeddings using model-specific encoder
embeddings = encode_image_embeddings(
model_name=self.model,
image_embeds=image_embeds,
vision_encoder=self.vision_encoder,
projector=self.projector,
)
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
# Move embeddings to CPU for NIXL transfer to avoid UCX/InfiniBand issues
embeddings_cpu = embeddings.cpu()
# Move embeddings to CPU for NIXL transfer to avoid UCX/InfiniBand issues
embeddings_cpu = embeddings.cpu()
request.multimodal_inputs[idx].image_grid_thw = image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple(
embeddings.shape
)
request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings_cpu)
if TRANSFER_LOCAL:
embedding_key = get_embedding_hash(image_url)
logger.debug(
f"ENCODER: saving local safetensors file with key {embedding_key}, {embeddings_cpu.numel()} * {embeddings_cpu.element_size()} bytes"
)
tensors = {"ec_cache": embeddings_cpu}
safetensors.torch.save_file(
tensors, f"/tmp/encoder_cache.{embedding_key}.safetensors"
)
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[
idx
].serialized_request = (
f"/tmp/encoder_cache.{embedding_key}.safetensors"
)
self.cached_embeddings[image_url] = (
embedding_key,
request.multimodal_inputs[idx].image_grid_thw,
request.multimodal_inputs[idx].embeddings_shape,
)
else:
# [gluo FIXME] nixl_connector path needs to be update to handle multiple embeddings
descriptor = connect.Descriptor(embeddings_cpu)
self.readables.append(
await self._connector.create_readable(descriptor)
)
request.multimodal_inputs[idx].serialized_request = self.readables[
-1
].metadata()
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None
request.multimodal_inputs[idx].multimodal_input.image_url = None
logger.debug(f"Request: {request.model_dump_json()}")
logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json(), context=context
)
await readable.wait_for_completion()
async for response in response_generator:
output = MyRequestOutput.model_validate_json(response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
).model_dump_json()
time_end = time.perf_counter()
self._accumulated_time += time_end - time_start
self._processed_requests += 1
logger.debug(
f"Encoded image(s) for request {{ id: {request_id} }} in {time_end - time_start:.4f} seconds. "
f"Average encoding time: {self._accumulated_time / self._processed_requests:.4f} seconds over {self._processed_requests} requests."
)
# Yield transformed request back
yield request.model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
......@@ -240,7 +293,7 @@ class VLLMEncodeWorkerHandler:
try:
# Prompt can be a random string as the encoder is only interested in the multimodal data
prompt_dict = TextPrompt(
prompt="<image>", multi_modal_data={media_key: media}
prompt=request.prompt, multi_modal_data={media_key: media}
)
gen = self.engine_client.generate(
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import uuid
from collections import defaultdict
from enum import Enum
from typing import AsyncIterator, Final
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams as VllmSamplingParams
from dynamo.runtime import Client
from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import (
MultiModalGroup,
MultiModalInput,
MyRequestOutput,
PatchedTokensPrompt,
ProcessMixIn,
vLLMMultimodalRequest,
)
logger = logging.getLogger(__name__)
# Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url"
VIDEO_URL_KEY: Final = "video_url"
URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded"
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class PreprocessedHandler(ProcessMixIn):
"""
vLLM pre and post processing for multimodal requests
"""
def __init__(
self,
engine_args: AsyncEngineArgs,
encode_worker_client: Client,
pd_worker_client: Client,
):
self.encode_worker_client = encode_worker_client
self.pd_worker_client = pd_worker_client
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
def cleanup(self):
pass
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request,
multimodal_inputs,
context,
):
# [gluo NOTE] panic for now as encoder here is for image only
if VIDEO_URL_KEY in multimodal_inputs or multimodal_inputs[VIDEO_URL_KEY]:
raise ValueError("Video URL not supported in encode worker yet")
request_id = str(uuid.uuid4().hex)
# Build sampling params from request using shared utility
sampling_params = build_sampling_params(
raw_request, self.default_sampling_params
)
# [gluo WIP] encoder doesn't really need any of this
encode_request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(prompt_token_ids=[]),
sampling_params=VllmSamplingParams(),
request_id=request_id,
multimodal_inputs=[],
)
# [gluo WIP] experiment with batching..
ENCODE_BATCH_SIZE = 1
encode_res_gen = []
for mm_type, urls in multimodal_inputs.items():
for url in urls:
multimodal_input = MultiModalInput()
if mm_type == IMAGE_URL_KEY:
multimodal_input.image_url = url
elif mm_type == VIDEO_URL_KEY:
multimodal_input.video_url = url
# [gluo NOTE] should not reach here due to earlier check
continue
encode_request.multimodal_inputs.append(
MultiModalGroup(multimodal_input=multimodal_input)
)
if len(encode_request.multimodal_inputs) >= ENCODE_BATCH_SIZE:
# model_dump_json() serializes the request to JSON string
# This API could accept Pydantic class, but SamplingParams
# in vLLMMultimodalRequest is not a Pydantic class and will
# cause TypeError: unsupported type SamplingParams
encode_res_gen.append(
await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
)
)
encode_request.multimodal_inputs = []
if encode_request.multimodal_inputs:
encode_res_gen.append(
await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
)
)
# Gather transformed requests
worker_request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(
prompt_token_ids=raw_request["token_ids"]
),
sampling_params=sampling_params,
request_id=request_id,
multimodal_inputs=[], # will be filled in next
)
for encode_res in encode_res_gen:
async for response in encode_res:
logger.debug(f"Received response from encode worker: {response}")
output = vLLMMultimodalRequest.model_validate_json(response.data())
worker_request.multimodal_inputs.extend(output.multimodal_inputs)
response_generator = await self.pd_worker_client.round_robin(
worker_request.model_dump_json(), context=context
)
# [gluo FIXME] <im_end> being returned
async for output in self._generate_responses(response_generator):
yield output
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
):
# [gluo WIP] modified from handler.py (BaseWorkerHandler.generate_tokens)
num_output_tokens_so_far = 0
try:
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
res = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if not res.outputs:
continue
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
# Extract logprobs for new tokens if available
log_probs, top_logprobs = BaseWorkerHandler._extract_logprobs(
output, num_output_tokens_so_far
)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs
if output.finish_reason:
out["finish_reason"] = output.finish_reason
out["completion_usage"] = BaseWorkerHandler._build_completion_usage(
request_output=res
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise EngineShGeneratorExit when engine exits so that frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None
def _extract_multimodal_data(self, request):
"""
Extract and decode multimodal data from PreprocessedRequest.
"""
# [gluo NOTE] modified from components/src/dynamo/vllm/handlers.py
if "multi_modal_data" not in request or request["multi_modal_data"] is None:
return {}
# [gluo FIXME] add this security option
# Security check: reject multimodal data if not explicitly enabled
# if not self.enable_multimodal:
# raise ValueError(
# "Received multimodal data but multimodal processing is not enabled. "
# "Use --enable-multimodal flag to enable multimodal processing."
# )
mm_map = request["multi_modal_data"]
multimodal_inputs = defaultdict(list)
for mm_type in [IMAGE_URL_KEY, VIDEO_URL_KEY]:
for item in mm_map.get(mm_type, []):
if isinstance(item, dict) and URL_VARIANT_KEY in item:
multimodal_inputs[mm_type].append(item[URL_VARIANT_KEY])
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
# Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
# Will contain NIXL metadata for direct memory access
# TODO: Implement NIXL read when PRs merge
logger.warning(
"Decoded multimodal data not yet supported in standard worker"
)
return multimodal_inputs
# The generate endpoint will be used by the frontend to handle incoming requests.
async def generate(self, request, context):
logger.debug(f"Got preprocessed request: {request}")
# Extract multimodal inputs for dispatching to encode worker
multimodal_inputs = self._extract_multimodal_data(request)
if not multimodal_inputs:
raise ValueError("Either image URL or video URL is required")
elif len(multimodal_inputs) > 1:
raise ValueError(
"Only one of image URL or video URL is supported per request"
)
async for response in self._generate(request, multimodal_inputs, context):
yield response
......@@ -19,6 +19,7 @@ from dynamo.runtime import Client
from ..multimodal_utils import (
ChatProcessor,
CompletionsProcessor,
MultiModalGroup,
MultiModalInput,
MultiModalRequest,
MyRequestOutput,
......@@ -264,6 +265,7 @@ class ECProcessorHandler(ProcessorHandler):
@staticmethod
def _create_encoder_request(
prompt: str,
mm_item: Dict[str, Any],
model: str,
request_id: str,
......@@ -282,6 +284,7 @@ class ECProcessorHandler(ProcessorHandler):
raise ValueError(f"Unsupported multimodal type: {mm_item.get('type')}")
return {
"prompt": prompt,
"request_id": request_id,
"multimodal_input": multimodal_input,
"modality": modality,
......@@ -289,6 +292,7 @@ class ECProcessorHandler(ProcessorHandler):
async def _encode_multimodal_items(
self,
prompt: str,
mm_items: List[Dict[str, Any]],
model: str,
request_id: str,
......@@ -312,6 +316,7 @@ class ECProcessorHandler(ProcessorHandler):
# Build encoder request
encoder_request = self._create_encoder_request(
prompt=prompt,
mm_item=mm_item,
model=model,
request_id=item_request_id,
......@@ -421,6 +426,7 @@ class ECProcessorHandler(ProcessorHandler):
)
try:
await self._encode_multimodal_items(
prompt=prompt,
mm_items=mm_items,
model=raw_request.model,
request_id=request_id,
......@@ -453,7 +459,9 @@ class ECProcessorHandler(ProcessorHandler):
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
multimodal_input=multimodal_input, # ✓ Keep this so vLLM can generate mm_hash
multimodal_inputs=[
MultiModalGroup(multimodal_input=multimodal_input)
], # ✓ Keep this so vLLM can generate mm_hash
)
logger.debug(
......
......@@ -3,7 +3,10 @@
import copy
import logging
import os
from collections import defaultdict
import safetensors
import torch
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
......@@ -22,6 +25,8 @@ from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_m
logger = logging.getLogger(__name__)
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
"""Decode worker for disaggregated multimodal serving"""
......@@ -164,86 +169,118 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
if self.config.ec_consumer_mode:
logger.debug(
f"[{request.request_id}] ECConnector consumer mode: "
f"vLLM will load embeddings from cache using mm_hash"
)
# Use PIL image loading - vLLM will detect it's already in EC cache
# and load from disk instead of reprocessing
if request.multimodal_input and request.multimodal_input.image_url:
multi_modal_data = {
"image": await self.image_loader.load_image(
request.multimodal_input.image_url
multi_modal_data = defaultdict(list)
for mi in request.multimodal_inputs:
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
if self.config.ec_consumer_mode:
logger.debug(
f"[{request.request_id}] ECConnector consumer mode: "
f"vLLM will load embeddings from cache using mm_hash"
)
# Use PIL image loading - vLLM will detect it's already in EC cache
# and load from disk instead of reprocessing
if mi.multimodal_input.image_url:
multi_modal_data["image"].append(
await self.image_loader.load_image(
mi.multimodal_input.image_url
)
)
}
elif request.multimodal_input and request.multimodal_input.video_url:
# For video, load as image placeholder (vLLM will use EC cache)
multi_modal_data = {
"image": await self.image_loader.load_image(
request.multimodal_input.video_url
elif mi.multimodal_input.video_url:
# For video, load as image placeholder (vLLM will use EC cache)
multi_modal_data["image"].append(
await self.image_loader.load_image(
request.multimodal_input.video_url
)
)
}
else:
raise ValueError(
"ECConnector mode requires multimodal_input with image/video URL"
)
elif (
request.multimodal_input is not None
and request.multimodal_input.image_url is None
and request.multimodal_input.video_url is None
):
# Network transfer mode: receive embeddings via connector from encoder worker
# Create a descriptor based on the embedding shape.
embeddings = torch.empty(
request.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
else:
raise ValueError(
"ECConnector mode requires multimodal_input with image/video URL"
)
elif (
mi.multimodal_input.image_url is None
and mi.multimodal_input.video_url is None
):
# Process embeddings using the connector
# Create a descriptor based on the embedding shape.
if TRANSFER_LOCAL:
logger.info("PD: Loading local safetensors file")
embeddings = safetensors.torch.load_file(mi.serialized_request)[
"ec_cache"
]
else:
embeddings = torch.empty(
mi.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op = await self._connector.begin_read(
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
if "video" in self.config.model.lower():
video_numpy = embeddings.numpy()
multi_modal_data = construct_mm_data(
self.config.model,
self.EMBEDDINGS_DTYPE,
video_numpy=video_numpy,
)
read_op = await self._connector.begin_read(
mi.serialized_request, descriptor
)
await read_op.wait_for_completion()
if "video" in self.config.model.lower():
video_numpy = embeddings.numpy()
mm_data = construct_mm_data(
self.config.model,
self.EMBEDDINGS_DTYPE,
video_numpy=video_numpy,
)
multi_modal_data["video"].append(mm_data["video"])
else:
mm_data = construct_mm_data(
self.config.model,
self.EMBEDDINGS_DTYPE,
image_embeds=embeddings,
image_grid_thw=mi.image_grid_thw,
)
if isinstance(mm_data["image"], dict):
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
# Merging tensors
multi_modal_data["image"]["image_embeds"] = torch.cat(
(
multi_modal_data["image"]["image_embeds"],
mm_data["image"]["image_embeds"],
)
)
multi_modal_data["image"]["image_grid_thw"] = torch.cat(
(
multi_modal_data["image"]["image_grid_thw"],
mm_data["image"]["image_grid_thw"],
)
)
else:
logger.info(f"Get embedding of shape {mm_data['image'].shape}")
# [gluo FIXME] embedding with multiple images?
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"])
)
else:
multi_modal_data = construct_mm_data(
self.config.model,
self.EMBEDDINGS_DTYPE,
image_embeds=embeddings,
image_grid_thw=request.image_grid_thw,
# Use PIL image instead of image embeddings
multi_modal_data["image"].append(
await self.image_loader.load_image(mi.multimodal_input.image_url)
)
elif request.multimodal_input is not None:
# Native mode: Use PIL image instead of image embeddings
multi_modal_data = {
"image": await self.image_loader.load_image(
request.multimodal_input.image_url
)
}
else:
raise ValueError(
"Invalid request: multimodal_input is None but not in ec_consumer_mode"
)
# Clear multimodal_input fields if present (not needed for engine)
if request.multimodal_input is not None:
request.multimodal_input.image_url = None
request.multimodal_input.video_url = None
request.serialized_request = None
# Remove the image features from the request as they are not required
request.multimodal_inputs = None
logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
logger.info(f"{multi_modal_data}")
# Deepcopy the request to avoid modifying the original
# when we adjust sampling params for prefill
pd_request = copy.deepcopy(request)
# Do prefill and remote decode if enable_disagg is true
......@@ -311,6 +348,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
logger.debug(
f"length of expanded prompt ids: {len(response.prompt_token_ids)}"
)
# logger.info(f"Response outputs: {response.outputs}")
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
......
......@@ -8,6 +8,7 @@ from dynamo.vllm.multimodal_utils.chat_processor import (
)
from dynamo.vllm.multimodal_utils.encode_utils import (
encode_image_embeddings,
get_embedding_hash,
get_encoder_components,
)
from dynamo.vllm.multimodal_utils.http_client import get_http_client
......@@ -18,9 +19,11 @@ from dynamo.vllm.multimodal_utils.model import (
load_vision_model,
)
from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput,
MultiModalRequest,
MyRequestOutput,
PatchedTokensPrompt,
VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse,
vLLMMultimodalRequest,
......@@ -38,6 +41,9 @@ __all__ = [
"construct_mm_data",
"load_vision_model",
"MultiModalInput",
"MultiModalGroup",
"PatchedTokensPrompt",
"get_embedding_hash",
"MultiModalRequest",
"MyRequestOutput",
"vLLMMultimodalRequest",
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import logging
from typing import Any, Dict, Optional
......@@ -25,6 +26,18 @@ from .model import SupportedModels, is_model_supported, is_qwen_vl_model
logger = logging.getLogger(__name__)
def get_embedding_hash(key: str) -> str:
"""
Generate a unique hash key for storing/retrieving image embeddings.
Args:
key: The base key string (e.g., image URL or identifier)
Returns:
A unique hash string for the given key.
"""
return hashlib.sha256(key.encode()).hexdigest()
def get_qwen_image_features(
vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
......
......@@ -29,6 +29,7 @@ class SupportedModels:
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
......@@ -105,6 +106,7 @@ def is_model_supported(model_name: str, supported_model: str) -> bool:
QWEN_VL_MODELS = [
SupportedModels.QWEN_2_VL_2B,
SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_3B,
]
......
......@@ -140,20 +140,26 @@ class MultiModalInput(BaseModel):
video_url: Optional[str] = None
class vLLMMultimodalRequest(vLLMGenerateRequest):
class MultiModalGroup(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata] = None
serialized_request: Optional[connect.RdmaMetadata | str] = None
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
multimodal_inputs: List[MultiModalGroup] = Field(default_factory=list)
class VLLMNativeEncoderRequest(BaseModel):
"""Request for vLLM-native encoder worker using ECConnector"""
request_id: str
prompt: str
multimodal_input: MultiModalInput
modality: Literal["image", "video", "audio"]
batch_items: Optional[List[MultiModalInput]] = None # For future batch processing
......
......@@ -16,9 +16,7 @@ set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
SINGLE_GPU=false
# Parse command line arguments
......@@ -28,10 +26,6 @@ while [[ $# -gt 0 ]]; do
MODEL_NAME=$2
shift 2
;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
--single-gpu)
SINGLE_GPU=true
shift
......@@ -40,7 +34,6 @@ while [[ $# -gt 0 ]]; do
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
echo " --single-gpu Run both encode and PD workers on GPU 0 (for pre-merge CI)"
echo " -h, --help Show this help message"
exit 0
......@@ -53,22 +46,6 @@ while [[ $# -gt 0 ]]; do
esac
done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]] || [[ "$MODEL_NAME" == "Qwen/Qwen2-VL-2B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
# Start frontend (HTTP endpoint)
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python -m dynamo.frontend &
......@@ -78,14 +55,14 @@ python -m dynamo.frontend &
# Multi-GPU mode: Each worker gets its own GPU, so use higher memory settings
EXTRA_ARGS=""
if [[ "$SINGLE_GPU" == "true" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.3 --max-model-len 3072 --enforce-eager"
EXTRA_ARGS="--gpu-memory-utilization 0.5 --enforce-eager --max-model-len 30426"
else
# Multi-GPU mode: standard memory settings
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 4096"
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 34096"
fi
# Start processor (Python-based preprocessing, handles prompt templating)
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
# run E/P/D workers
# Use single GPU (GPU 0) for pre-merge CI, otherwise use GPU 0 for encode and GPU 1 for PD
......
......@@ -12,7 +12,7 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
python -m dynamo.frontend &
# run processor
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" &
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
# Llama 4 doesn't support image embedding input, so use encode+prefill worker
# that handles image encoding inline
python -m dynamo.vllm --multimodal-encode-prefill-worker --enable-multimodal --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 &
......
......@@ -6,8 +6,6 @@ trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
......@@ -16,10 +14,6 @@ while [[ $# -gt 0 ]]; do
MODEL_NAME=$2
shift 2
;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo ""
......@@ -27,7 +21,6 @@ while [[ $# -gt 0 ]]; do
echo ""
echo "Options:"
echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use"
echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates"
echo " -h, --help Show this help message"
echo ""
......@@ -46,27 +39,11 @@ while [[ $# -gt 0 ]]; do
esac
done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
echo "=================================================="
echo "Disaggregated Multimodal Serving"
echo "=================================================="
echo "Model: $MODEL_NAME"
echo "Prompt Template: $PROMPT_TEMPLATE"
echo "=================================================="
......@@ -77,7 +54,7 @@ python -m dynamo.frontend &
# Start processor
echo "Starting processor..."
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
EXTRA_ARGS=""
......
......@@ -49,7 +49,7 @@ if [[ $HEAD_NODE -eq 1 ]]; then
python -m dynamo.frontend &
# run processor
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" &
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
# Llama 4 doesn't support image embedding input, so the prefill worker will also
# handle image encoding inline.
......
......@@ -296,7 +296,9 @@ vllm_configs = {
},
],
repeat_count=1,
expected_response=["purple"],
# With proper prompt templating, the model actually only returns "green",
# verified behavior with native vLLM.
expected_response=["green"],
temperature=0.0,
max_tokens=100,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment