Unverified Commit 334cbd9b authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

refactor: vllm EPD refactor (#4994)


Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 6f68be40
...@@ -34,7 +34,7 @@ from dynamo.vllm.multimodal_handlers import ( ...@@ -34,7 +34,7 @@ from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler, EncodeWorkerHandler,
MultimodalDecodeWorkerHandler, MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler, MultimodalPDWorkerHandler,
ProcessorHandler, PreprocessedHandler,
VLLMEncodeWorkerHandler, VLLMEncodeWorkerHandler,
) )
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
...@@ -676,13 +676,17 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config) ...@@ -676,13 +676,17 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
.client() .client()
) )
# Get prompt template from args (must be passed via environment or command line) pd_worker_client = (
mm_prompt_template = config.mm_prompt_template await runtime.namespace(config.namespace)
.component("backend")
.endpoint("generate")
.client()
)
handler = ProcessorHandler( handler = PreprocessedHandler(
config.engine_args, config.engine_args,
encode_worker_client, encode_worker_client,
mm_prompt_template, pd_worker_client,
) )
logger.info("Waiting for Encoder Worker Instances ...") logger.info("Waiting for Encoder Worker Instances ...")
...@@ -690,7 +694,7 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config) ...@@ -690,7 +694,7 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
# Register the endpoint as entrypoint to a model # Register the endpoint as entrypoint to a model
await register_llm( await register_llm(
ModelInput.Text, # Custom processor is used and this type bypasses SDK processor ModelInput.Tokens,
ModelType.Chat, ModelType.Chat,
generate_endpoint, generate_endpoint,
config.model, config.model,
......
...@@ -5,10 +5,8 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import ( ...@@ -5,10 +5,8 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import (
EncodeWorkerHandler, EncodeWorkerHandler,
VLLMEncodeWorkerHandler, VLLMEncodeWorkerHandler,
) )
from dynamo.vllm.multimodal_handlers.preprocessor_handler import ( from dynamo.vllm.multimodal_handlers.preprocessed_handler import PreprocessedHandler
ECProcessorHandler, from dynamo.vllm.multimodal_handlers.preprocessor_handler import ECProcessorHandler
ProcessorHandler,
)
from dynamo.vllm.multimodal_handlers.worker_handler import ( from dynamo.vllm.multimodal_handlers.worker_handler import (
MultimodalDecodeWorkerHandler, MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler, MultimodalPDWorkerHandler,
...@@ -17,7 +15,7 @@ from dynamo.vllm.multimodal_handlers.worker_handler import ( ...@@ -17,7 +15,7 @@ from dynamo.vllm.multimodal_handlers.worker_handler import (
__all__ = [ __all__ = [
"EncodeWorkerHandler", "EncodeWorkerHandler",
"VLLMEncodeWorkerHandler", "VLLMEncodeWorkerHandler",
"ProcessorHandler", "PreprocessedHandler",
"MultimodalPDWorkerHandler", "MultimodalPDWorkerHandler",
"MultimodalDecodeWorkerHandler", "MultimodalDecodeWorkerHandler",
"ECProcessorHandler", "ECProcessorHandler",
......
...@@ -2,9 +2,12 @@ ...@@ -2,9 +2,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
import os
import shutil import shutil
import time
from typing import AsyncGenerator, AsyncIterator from typing import AsyncGenerator, AsyncIterator
import safetensors
from transformers import AutoImageProcessor from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TextPrompt from vllm.inputs.data import TextPrompt
...@@ -16,10 +19,10 @@ from dynamo.runtime import Client, DistributedRuntime ...@@ -16,10 +19,10 @@ from dynamo.runtime import Client, DistributedRuntime
from ..multimodal_utils import ( from ..multimodal_utils import (
ImageLoader, ImageLoader,
MyRequestOutput,
VLLMNativeEncoderRequest, VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse, VLLMNativeEncoderResponse,
encode_image_embeddings, encode_image_embeddings,
get_embedding_hash,
get_encoder_components, get_encoder_components,
load_vision_model, load_vision_model,
vLLMMultimodalRequest, vLLMMultimodalRequest,
...@@ -42,6 +45,8 @@ except ImportError as e: ...@@ -42,6 +45,8 @@ except ImportError as e:
CACHE_SIZE_MAXIMUM = 8 CACHE_SIZE_MAXIMUM = 8
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class EncodeWorkerHandler: class EncodeWorkerHandler:
def __init__( def __init__(
...@@ -65,6 +70,10 @@ class EncodeWorkerHandler: ...@@ -65,6 +70,10 @@ class EncodeWorkerHandler:
self.model, self.vision_model self.model, self.vision_model
) )
self._connector = None self._connector = None
self._accumulated_time = 0.0
self._processed_requests = 0
self.readables = []
self.cached_embeddings = {}
def cleanup(self): def cleanup(self):
pass pass
...@@ -101,63 +110,107 @@ class EncodeWorkerHandler: ...@@ -101,63 +110,107 @@ class EncodeWorkerHandler:
# 8. Yield the encode response. # 8. Yield the encode response.
try: try:
if not request.multimodal_input.image_url: time_start = time.perf_counter()
raise ValueError("image_url is required for the encode worker.") for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url:
image = await self.image_loader.load_image( raise ValueError("image_url is required for the encode worker.")
request.multimodal_input.image_url
) image_url = request.multimodal_inputs[idx].multimodal_input.image_url
# see if we have local cache
logger.debug(f"Processing image for request: {{ id: {request_id} }}") if image_url in self.cached_embeddings:
image_embeds = self.image_processor(images=image, return_tensors="pt") (
embedding_key,
image_grid_thw,
embeddings_shape,
) = self.cached_embeddings[image_url]
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[
idx
].serialized_request = (
f"/tmp/encoder_cache.{embedding_key}.safetensors"
)
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[idx].image_grid_thw = image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = embeddings_shape
continue
image = await self.image_loader.load_image(image_url)
logger.debug(
f"Processing image {image_url} for request: {{ id: {request_id} }}"
)
image_embeds = self.image_processor(images=image, return_tensors="pt")
# Encode the image embeddings using model-specific encoder
embeddings = encode_image_embeddings(
model_name=self.model,
image_embeds=image_embeds,
vision_encoder=self.vision_encoder,
projector=self.projector,
)
# Encode the image embeddings using model-specific encoder image_grid_thw = (
embeddings = encode_image_embeddings( image_embeds["image_grid_thw"].tolist()
model_name=self.model, if "image_grid_thw" in image_embeds
image_embeds=image_embeds, else None
vision_encoder=self.vision_encoder, )
projector=self.projector, logger.debug(
) f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
image_grid_thw = ( # Move embeddings to CPU for NIXL transfer to avoid UCX/InfiniBand issues
image_embeds["image_grid_thw"].tolist() embeddings_cpu = embeddings.cpu()
if "image_grid_thw" in image_embeds
else None
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
# Move embeddings to CPU for NIXL transfer to avoid UCX/InfiniBand issues request.multimodal_inputs[idx].image_grid_thw = image_grid_thw
embeddings_cpu = embeddings.cpu() request.multimodal_inputs[idx].embeddings_shape = tuple(
embeddings.shape
)
request.image_grid_thw = image_grid_thw if TRANSFER_LOCAL:
request.embeddings_shape = tuple(embeddings.shape) embedding_key = get_embedding_hash(image_url)
descriptor = connect.Descriptor(embeddings_cpu) logger.debug(
f"ENCODER: saving local safetensors file with key {embedding_key}, {embeddings_cpu.numel()} * {embeddings_cpu.element_size()} bytes"
)
tensors = {"ec_cache": embeddings_cpu}
safetensors.torch.save_file(
tensors, f"/tmp/encoder_cache.{embedding_key}.safetensors"
)
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[
idx
].serialized_request = (
f"/tmp/encoder_cache.{embedding_key}.safetensors"
)
self.cached_embeddings[image_url] = (
embedding_key,
request.multimodal_inputs[idx].image_grid_thw,
request.multimodal_inputs[idx].embeddings_shape,
)
else:
# [gluo FIXME] nixl_connector path needs to be update to handle multiple embeddings
descriptor = connect.Descriptor(embeddings_cpu)
self.readables.append(
await self._connector.create_readable(descriptor)
)
request.multimodal_inputs[idx].serialized_request = self.readables[
-1
].metadata()
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings. # Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_input.image_url = None request.multimodal_inputs[idx].multimodal_input.image_url = None
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator time_end = time.perf_counter()
response_generator = await self.pd_worker_client.round_robin( self._accumulated_time += time_end - time_start
request.model_dump_json(), context=context self._processed_requests += 1
) logger.debug(
await readable.wait_for_completion() f"Encoded image(s) for request {{ id: {request_id} }} in {time_end - time_start:.4f} seconds. "
f"Average encoding time: {self._accumulated_time / self._processed_requests:.4f} seconds over {self._processed_requests} requests."
async for response in response_generator: )
output = MyRequestOutput.model_validate_json(response.data())
yield MyRequestOutput( # Yield transformed request back
request_id=output.request_id, yield request.model_dump_json()
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
).model_dump_json()
except Exception as e: except Exception as e:
logger.error(f"Error processing request {request_id}: {e}") logger.error(f"Error processing request {request_id}: {e}")
...@@ -240,7 +293,7 @@ class VLLMEncodeWorkerHandler: ...@@ -240,7 +293,7 @@ class VLLMEncodeWorkerHandler:
try: try:
# Prompt can be a random string as the encoder is only interested in the multimodal data # Prompt can be a random string as the encoder is only interested in the multimodal data
prompt_dict = TextPrompt( prompt_dict = TextPrompt(
prompt="<image>", multi_modal_data={media_key: media} prompt=request.prompt, multi_modal_data={media_key: media}
) )
gen = self.engine_client.generate( gen = self.engine_client.generate(
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import uuid
from collections import defaultdict
from enum import Enum
from typing import AsyncIterator, Final
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams as VllmSamplingParams
from dynamo.runtime import Client
from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import (
MultiModalGroup,
MultiModalInput,
MyRequestOutput,
PatchedTokensPrompt,
ProcessMixIn,
vLLMMultimodalRequest,
)
logger = logging.getLogger(__name__)
# Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url"
VIDEO_URL_KEY: Final = "video_url"
URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded"
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class PreprocessedHandler(ProcessMixIn):
"""
vLLM pre and post processing for multimodal requests
"""
def __init__(
self,
engine_args: AsyncEngineArgs,
encode_worker_client: Client,
pd_worker_client: Client,
):
self.encode_worker_client = encode_worker_client
self.pd_worker_client = pd_worker_client
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
def cleanup(self):
pass
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request,
multimodal_inputs,
context,
):
# [gluo NOTE] panic for now as encoder here is for image only
if VIDEO_URL_KEY in multimodal_inputs or multimodal_inputs[VIDEO_URL_KEY]:
raise ValueError("Video URL not supported in encode worker yet")
request_id = str(uuid.uuid4().hex)
# Build sampling params from request using shared utility
sampling_params = build_sampling_params(
raw_request, self.default_sampling_params
)
# [gluo WIP] encoder doesn't really need any of this
encode_request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(prompt_token_ids=[]),
sampling_params=VllmSamplingParams(),
request_id=request_id,
multimodal_inputs=[],
)
# [gluo WIP] experiment with batching..
ENCODE_BATCH_SIZE = 1
encode_res_gen = []
for mm_type, urls in multimodal_inputs.items():
for url in urls:
multimodal_input = MultiModalInput()
if mm_type == IMAGE_URL_KEY:
multimodal_input.image_url = url
elif mm_type == VIDEO_URL_KEY:
multimodal_input.video_url = url
# [gluo NOTE] should not reach here due to earlier check
continue
encode_request.multimodal_inputs.append(
MultiModalGroup(multimodal_input=multimodal_input)
)
if len(encode_request.multimodal_inputs) >= ENCODE_BATCH_SIZE:
# model_dump_json() serializes the request to JSON string
# This API could accept Pydantic class, but SamplingParams
# in vLLMMultimodalRequest is not a Pydantic class and will
# cause TypeError: unsupported type SamplingParams
encode_res_gen.append(
await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
)
)
encode_request.multimodal_inputs = []
if encode_request.multimodal_inputs:
encode_res_gen.append(
await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
)
)
# Gather transformed requests
worker_request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(
prompt_token_ids=raw_request["token_ids"]
),
sampling_params=sampling_params,
request_id=request_id,
multimodal_inputs=[], # will be filled in next
)
for encode_res in encode_res_gen:
async for response in encode_res:
logger.debug(f"Received response from encode worker: {response}")
output = vLLMMultimodalRequest.model_validate_json(response.data())
worker_request.multimodal_inputs.extend(output.multimodal_inputs)
response_generator = await self.pd_worker_client.round_robin(
worker_request.model_dump_json(), context=context
)
# [gluo FIXME] <im_end> being returned
async for output in self._generate_responses(response_generator):
yield output
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
):
# [gluo WIP] modified from handler.py (BaseWorkerHandler.generate_tokens)
num_output_tokens_so_far = 0
try:
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
res = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if not res.outputs:
continue
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
# Extract logprobs for new tokens if available
log_probs, top_logprobs = BaseWorkerHandler._extract_logprobs(
output, num_output_tokens_so_far
)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs
if output.finish_reason:
out["finish_reason"] = output.finish_reason
out["completion_usage"] = BaseWorkerHandler._build_completion_usage(
request_output=res
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise EngineShGeneratorExit when engine exits so that frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None
def _extract_multimodal_data(self, request):
"""
Extract and decode multimodal data from PreprocessedRequest.
"""
# [gluo NOTE] modified from components/src/dynamo/vllm/handlers.py
if "multi_modal_data" not in request or request["multi_modal_data"] is None:
return {}
# [gluo FIXME] add this security option
# Security check: reject multimodal data if not explicitly enabled
# if not self.enable_multimodal:
# raise ValueError(
# "Received multimodal data but multimodal processing is not enabled. "
# "Use --enable-multimodal flag to enable multimodal processing."
# )
mm_map = request["multi_modal_data"]
multimodal_inputs = defaultdict(list)
for mm_type in [IMAGE_URL_KEY, VIDEO_URL_KEY]:
for item in mm_map.get(mm_type, []):
if isinstance(item, dict) and URL_VARIANT_KEY in item:
multimodal_inputs[mm_type].append(item[URL_VARIANT_KEY])
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
# Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
# Will contain NIXL metadata for direct memory access
# TODO: Implement NIXL read when PRs merge
logger.warning(
"Decoded multimodal data not yet supported in standard worker"
)
return multimodal_inputs
# The generate endpoint will be used by the frontend to handle incoming requests.
async def generate(self, request, context):
logger.debug(f"Got preprocessed request: {request}")
# Extract multimodal inputs for dispatching to encode worker
multimodal_inputs = self._extract_multimodal_data(request)
if not multimodal_inputs:
raise ValueError("Either image URL or video URL is required")
elif len(multimodal_inputs) > 1:
raise ValueError(
"Only one of image URL or video URL is supported per request"
)
async for response in self._generate(request, multimodal_inputs, context):
yield response
...@@ -19,6 +19,7 @@ from dynamo.runtime import Client ...@@ -19,6 +19,7 @@ from dynamo.runtime import Client
from ..multimodal_utils import ( from ..multimodal_utils import (
ChatProcessor, ChatProcessor,
CompletionsProcessor, CompletionsProcessor,
MultiModalGroup,
MultiModalInput, MultiModalInput,
MultiModalRequest, MultiModalRequest,
MyRequestOutput, MyRequestOutput,
...@@ -264,6 +265,7 @@ class ECProcessorHandler(ProcessorHandler): ...@@ -264,6 +265,7 @@ class ECProcessorHandler(ProcessorHandler):
@staticmethod @staticmethod
def _create_encoder_request( def _create_encoder_request(
prompt: str,
mm_item: Dict[str, Any], mm_item: Dict[str, Any],
model: str, model: str,
request_id: str, request_id: str,
...@@ -282,6 +284,7 @@ class ECProcessorHandler(ProcessorHandler): ...@@ -282,6 +284,7 @@ class ECProcessorHandler(ProcessorHandler):
raise ValueError(f"Unsupported multimodal type: {mm_item.get('type')}") raise ValueError(f"Unsupported multimodal type: {mm_item.get('type')}")
return { return {
"prompt": prompt,
"request_id": request_id, "request_id": request_id,
"multimodal_input": multimodal_input, "multimodal_input": multimodal_input,
"modality": modality, "modality": modality,
...@@ -289,6 +292,7 @@ class ECProcessorHandler(ProcessorHandler): ...@@ -289,6 +292,7 @@ class ECProcessorHandler(ProcessorHandler):
async def _encode_multimodal_items( async def _encode_multimodal_items(
self, self,
prompt: str,
mm_items: List[Dict[str, Any]], mm_items: List[Dict[str, Any]],
model: str, model: str,
request_id: str, request_id: str,
...@@ -312,6 +316,7 @@ class ECProcessorHandler(ProcessorHandler): ...@@ -312,6 +316,7 @@ class ECProcessorHandler(ProcessorHandler):
# Build encoder request # Build encoder request
encoder_request = self._create_encoder_request( encoder_request = self._create_encoder_request(
prompt=prompt,
mm_item=mm_item, mm_item=mm_item,
model=model, model=model,
request_id=item_request_id, request_id=item_request_id,
...@@ -421,6 +426,7 @@ class ECProcessorHandler(ProcessorHandler): ...@@ -421,6 +426,7 @@ class ECProcessorHandler(ProcessorHandler):
) )
try: try:
await self._encode_multimodal_items( await self._encode_multimodal_items(
prompt=prompt,
mm_items=mm_items, mm_items=mm_items,
model=raw_request.model, model=raw_request.model,
request_id=request_id, request_id=request_id,
...@@ -453,7 +459,9 @@ class ECProcessorHandler(ProcessorHandler): ...@@ -453,7 +459,9 @@ class ECProcessorHandler(ProcessorHandler):
engine_prompt=engine_prompt, engine_prompt=engine_prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
multimodal_input=multimodal_input, # ✓ Keep this so vLLM can generate mm_hash multimodal_inputs=[
MultiModalGroup(multimodal_input=multimodal_input)
], # ✓ Keep this so vLLM can generate mm_hash
) )
logger.debug( logger.debug(
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
import copy import copy
import logging import logging
import os
from collections import defaultdict
import safetensors
import torch import torch
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
...@@ -22,6 +25,8 @@ from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_m ...@@ -22,6 +25,8 @@ from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_m
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
class MultimodalDecodeWorkerHandler(BaseWorkerHandler): class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
"""Decode worker for disaggregated multimodal serving""" """Decode worker for disaggregated multimodal serving"""
...@@ -164,86 +169,118 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -164,86 +169,118 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request = vLLMMultimodalRequest.model_validate(request) request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
# ECConnector consumer mode: vLLM loads embeddings automatically from disk multi_modal_data = defaultdict(list)
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache for mi in request.multimodal_inputs:
if self.config.ec_consumer_mode: # ECConnector consumer mode: vLLM loads embeddings automatically from disk
logger.debug( # We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
f"[{request.request_id}] ECConnector consumer mode: " if self.config.ec_consumer_mode:
f"vLLM will load embeddings from cache using mm_hash" logger.debug(
) f"[{request.request_id}] ECConnector consumer mode: "
# Use PIL image loading - vLLM will detect it's already in EC cache f"vLLM will load embeddings from cache using mm_hash"
# and load from disk instead of reprocessing )
if request.multimodal_input and request.multimodal_input.image_url: # Use PIL image loading - vLLM will detect it's already in EC cache
multi_modal_data = { # and load from disk instead of reprocessing
"image": await self.image_loader.load_image( if mi.multimodal_input.image_url:
request.multimodal_input.image_url multi_modal_data["image"].append(
await self.image_loader.load_image(
mi.multimodal_input.image_url
)
) )
} elif mi.multimodal_input.video_url:
elif request.multimodal_input and request.multimodal_input.video_url: # For video, load as image placeholder (vLLM will use EC cache)
# For video, load as image placeholder (vLLM will use EC cache) multi_modal_data["image"].append(
multi_modal_data = { await self.image_loader.load_image(
"image": await self.image_loader.load_image( request.multimodal_input.video_url
request.multimodal_input.video_url )
) )
} else:
else: raise ValueError(
raise ValueError( "ECConnector mode requires multimodal_input with image/video URL"
"ECConnector mode requires multimodal_input with image/video URL" )
) elif (
elif ( mi.multimodal_input.image_url is None
request.multimodal_input is not None and mi.multimodal_input.video_url is None
and request.multimodal_input.image_url is None ):
and request.multimodal_input.video_url is None # Process embeddings using the connector
): # Create a descriptor based on the embedding shape.
# Network transfer mode: receive embeddings via connector from encoder worker if TRANSFER_LOCAL:
# Create a descriptor based on the embedding shape. logger.info("PD: Loading local safetensors file")
embeddings = torch.empty( embeddings = safetensors.torch.load_file(mi.serialized_request)[
request.embeddings_shape, "ec_cache"
dtype=self.EMBEDDINGS_DTYPE, ]
device=self.EMBEDDINGS_DEVICE, else:
) embeddings = torch.empty(
descriptor = connect.Descriptor(embeddings) mi.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
if descriptor is None: if descriptor is None:
raise RuntimeError( raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings" "Descriptor is None in PD worker - cannot process embeddings"
) )
read_op = await self._connector.begin_read( read_op = await self._connector.begin_read(
request.serialized_request, descriptor mi.serialized_request, descriptor
) )
await read_op.wait_for_completion() await read_op.wait_for_completion()
if "video" in self.config.model.lower(): if "video" in self.config.model.lower():
video_numpy = embeddings.numpy() video_numpy = embeddings.numpy()
multi_modal_data = construct_mm_data( mm_data = construct_mm_data(
self.config.model, self.config.model,
self.EMBEDDINGS_DTYPE, self.EMBEDDINGS_DTYPE,
video_numpy=video_numpy, video_numpy=video_numpy,
) )
multi_modal_data["video"].append(mm_data["video"])
else:
mm_data = construct_mm_data(
self.config.model,
self.EMBEDDINGS_DTYPE,
image_embeds=embeddings,
image_grid_thw=mi.image_grid_thw,
)
if isinstance(mm_data["image"], dict):
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
# Merging tensors
multi_modal_data["image"]["image_embeds"] = torch.cat(
(
multi_modal_data["image"]["image_embeds"],
mm_data["image"]["image_embeds"],
)
)
multi_modal_data["image"]["image_grid_thw"] = torch.cat(
(
multi_modal_data["image"]["image_grid_thw"],
mm_data["image"]["image_grid_thw"],
)
)
else:
logger.info(f"Get embedding of shape {mm_data['image'].shape}")
# [gluo FIXME] embedding with multiple images?
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"])
)
else: else:
multi_modal_data = construct_mm_data( # Use PIL image instead of image embeddings
self.config.model, multi_modal_data["image"].append(
self.EMBEDDINGS_DTYPE, await self.image_loader.load_image(mi.multimodal_input.image_url)
image_embeds=embeddings,
image_grid_thw=request.image_grid_thw,
) )
elif request.multimodal_input is not None:
# Native mode: Use PIL image instead of image embeddings
multi_modal_data = {
"image": await self.image_loader.load_image(
request.multimodal_input.image_url
)
}
else:
raise ValueError(
"Invalid request: multimodal_input is None but not in ec_consumer_mode"
)
# Clear multimodal_input fields if present (not needed for engine) # Remove the image features from the request as they are not required
if request.multimodal_input is not None: request.multimodal_inputs = None
request.multimodal_input.image_url = None
request.multimodal_input.video_url = None logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
request.serialized_request = None logger.info(f"{multi_modal_data}")
# Deepcopy the request to avoid modifying the original
# when we adjust sampling params for prefill
pd_request = copy.deepcopy(request) pd_request = copy.deepcopy(request)
# Do prefill and remote decode if enable_disagg is true # Do prefill and remote decode if enable_disagg is true
...@@ -311,6 +348,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -311,6 +348,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug( logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}" f"Response kv_transfer_params: {response.kv_transfer_params}"
) )
logger.debug(
f"length of expanded prompt ids: {len(response.prompt_token_ids)}"
)
# logger.info(f"Response outputs: {response.outputs}")
yield MyRequestOutput( yield MyRequestOutput(
request_id=response.request_id, request_id=response.request_id,
prompt=response.prompt, prompt=response.prompt,
......
...@@ -8,6 +8,7 @@ from dynamo.vllm.multimodal_utils.chat_processor import ( ...@@ -8,6 +8,7 @@ from dynamo.vllm.multimodal_utils.chat_processor import (
) )
from dynamo.vllm.multimodal_utils.encode_utils import ( from dynamo.vllm.multimodal_utils.encode_utils import (
encode_image_embeddings, encode_image_embeddings,
get_embedding_hash,
get_encoder_components, get_encoder_components,
) )
from dynamo.vllm.multimodal_utils.http_client import get_http_client from dynamo.vllm.multimodal_utils.http_client import get_http_client
...@@ -18,9 +19,11 @@ from dynamo.vllm.multimodal_utils.model import ( ...@@ -18,9 +19,11 @@ from dynamo.vllm.multimodal_utils.model import (
load_vision_model, load_vision_model,
) )
from dynamo.vllm.multimodal_utils.protocol import ( from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput, MultiModalInput,
MultiModalRequest, MultiModalRequest,
MyRequestOutput, MyRequestOutput,
PatchedTokensPrompt,
VLLMNativeEncoderRequest, VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse, VLLMNativeEncoderResponse,
vLLMMultimodalRequest, vLLMMultimodalRequest,
...@@ -38,6 +41,9 @@ __all__ = [ ...@@ -38,6 +41,9 @@ __all__ = [
"construct_mm_data", "construct_mm_data",
"load_vision_model", "load_vision_model",
"MultiModalInput", "MultiModalInput",
"MultiModalGroup",
"PatchedTokensPrompt",
"get_embedding_hash",
"MultiModalRequest", "MultiModalRequest",
"MyRequestOutput", "MyRequestOutput",
"vLLMMultimodalRequest", "vLLMMultimodalRequest",
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import json import json
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
...@@ -25,6 +26,18 @@ from .model import SupportedModels, is_model_supported, is_qwen_vl_model ...@@ -25,6 +26,18 @@ from .model import SupportedModels, is_model_supported, is_qwen_vl_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_embedding_hash(key: str) -> str:
"""
Generate a unique hash key for storing/retrieving image embeddings.
Args:
key: The base key string (e.g., image URL or identifier)
Returns:
A unique hash string for the given key.
"""
return hashlib.sha256(key.encode()).hexdigest()
def get_qwen_image_features( def get_qwen_image_features(
vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any] vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -29,6 +29,7 @@ class SupportedModels: ...@@ -29,6 +29,7 @@ class SupportedModels:
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf" LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct" QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct" QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf" LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
...@@ -105,6 +106,7 @@ def is_model_supported(model_name: str, supported_model: str) -> bool: ...@@ -105,6 +106,7 @@ def is_model_supported(model_name: str, supported_model: str) -> bool:
QWEN_VL_MODELS = [ QWEN_VL_MODELS = [
SupportedModels.QWEN_2_VL_2B, SupportedModels.QWEN_2_VL_2B,
SupportedModels.QWEN_2_5_VL_7B, SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_3B,
] ]
......
...@@ -140,20 +140,26 @@ class MultiModalInput(BaseModel): ...@@ -140,20 +140,26 @@ class MultiModalInput(BaseModel):
video_url: Optional[str] = None video_url: Optional[str] = None
class vLLMMultimodalRequest(vLLMGenerateRequest): class MultiModalGroup(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput) multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[ embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]] Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None ] = None
serialized_request: Optional[connect.RdmaMetadata] = None serialized_request: Optional[connect.RdmaMetadata | str] = None
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
multimodal_inputs: List[MultiModalGroup] = Field(default_factory=list)
class VLLMNativeEncoderRequest(BaseModel): class VLLMNativeEncoderRequest(BaseModel):
"""Request for vLLM-native encoder worker using ECConnector""" """Request for vLLM-native encoder worker using ECConnector"""
request_id: str request_id: str
prompt: str
multimodal_input: MultiModalInput multimodal_input: MultiModalInput
modality: Literal["image", "video", "audio"] modality: Literal["image", "video", "audio"]
batch_items: Optional[List[MultiModalInput]] = None # For future batch processing batch_items: Optional[List[MultiModalInput]] = None # For future batch processing
......
...@@ -16,9 +16,7 @@ set -e ...@@ -16,9 +16,7 @@ set -e
trap 'echo Cleaning up...; kill 0' EXIT trap 'echo Cleaning up...; kill 0' EXIT
# Default values # Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf" MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
SINGLE_GPU=false SINGLE_GPU=false
# Parse command line arguments # Parse command line arguments
...@@ -28,10 +26,6 @@ while [[ $# -gt 0 ]]; do ...@@ -28,10 +26,6 @@ while [[ $# -gt 0 ]]; do
MODEL_NAME=$2 MODEL_NAME=$2
shift 2 shift 2
;; ;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
--single-gpu) --single-gpu)
SINGLE_GPU=true SINGLE_GPU=true
shift shift
...@@ -40,7 +34,6 @@ while [[ $# -gt 0 ]]; do ...@@ -40,7 +34,6 @@ while [[ $# -gt 0 ]]; do
echo "Usage: $0 [OPTIONS]" echo "Usage: $0 [OPTIONS]"
echo "Options:" echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)" echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
echo " --single-gpu Run both encode and PD workers on GPU 0 (for pre-merge CI)" echo " --single-gpu Run both encode and PD workers on GPU 0 (for pre-merge CI)"
echo " -h, --help Show this help message" echo " -h, --help Show this help message"
exit 0 exit 0
...@@ -53,22 +46,6 @@ while [[ $# -gt 0 ]]; do ...@@ -53,22 +46,6 @@ while [[ $# -gt 0 ]]; do
esac esac
done done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]] || [[ "$MODEL_NAME" == "Qwen/Qwen2-VL-2B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
# Start frontend (HTTP endpoint) # Start frontend (HTTP endpoint)
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000) # dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python -m dynamo.frontend & python -m dynamo.frontend &
...@@ -78,14 +55,14 @@ python -m dynamo.frontend & ...@@ -78,14 +55,14 @@ python -m dynamo.frontend &
# Multi-GPU mode: Each worker gets its own GPU, so use higher memory settings # Multi-GPU mode: Each worker gets its own GPU, so use higher memory settings
EXTRA_ARGS="" EXTRA_ARGS=""
if [[ "$SINGLE_GPU" == "true" ]]; then if [[ "$SINGLE_GPU" == "true" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.3 --max-model-len 3072 --enforce-eager" EXTRA_ARGS="--gpu-memory-utilization 0.5 --enforce-eager --max-model-len 30426"
else else
# Multi-GPU mode: standard memory settings # Multi-GPU mode: standard memory settings
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 4096" EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 34096"
fi fi
# Start processor (Python-based preprocessing, handles prompt templating) # Start processor (Python-based preprocessing, handles prompt templating)
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" & python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
# run E/P/D workers # run E/P/D workers
# Use single GPU (GPU 0) for pre-merge CI, otherwise use GPU 0 for encode and GPU 1 for PD # Use single GPU (GPU 0) for pre-merge CI, otherwise use GPU 0 for encode and GPU 1 for PD
......
...@@ -12,7 +12,7 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8" ...@@ -12,7 +12,7 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
python -m dynamo.frontend & python -m dynamo.frontend &
# run processor # run processor
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" & python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
# Llama 4 doesn't support image embedding input, so use encode+prefill worker # Llama 4 doesn't support image embedding input, so use encode+prefill worker
# that handles image encoding inline # that handles image encoding inline
python -m dynamo.vllm --multimodal-encode-prefill-worker --enable-multimodal --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 & python -m dynamo.vllm --multimodal-encode-prefill-worker --enable-multimodal --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 &
......
...@@ -6,8 +6,6 @@ trap 'echo Cleaning up...; kill 0' EXIT ...@@ -6,8 +6,6 @@ trap 'echo Cleaning up...; kill 0' EXIT
# Default values # Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf" MODEL_NAME="llava-hf/llava-1.5-7b-hf"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
# Parse command line arguments # Parse command line arguments
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
...@@ -16,10 +14,6 @@ while [[ $# -gt 0 ]]; do ...@@ -16,10 +14,6 @@ while [[ $# -gt 0 ]]; do
MODEL_NAME=$2 MODEL_NAME=$2
shift 2 shift 2
;; ;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
-h|--help) -h|--help)
echo "Usage: $0 [OPTIONS]" echo "Usage: $0 [OPTIONS]"
echo "" echo ""
...@@ -27,7 +21,6 @@ while [[ $# -gt 0 ]]; do ...@@ -27,7 +21,6 @@ while [[ $# -gt 0 ]]; do
echo "" echo ""
echo "Options:" echo "Options:"
echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)" echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use"
echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates" echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates"
echo " -h, --help Show this help message" echo " -h, --help Show this help message"
echo "" echo ""
...@@ -46,27 +39,11 @@ while [[ $# -gt 0 ]]; do ...@@ -46,27 +39,11 @@ while [[ $# -gt 0 ]]; do
esac esac
done done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
echo "==================================================" echo "=================================================="
echo "Disaggregated Multimodal Serving" echo "Disaggregated Multimodal Serving"
echo "==================================================" echo "=================================================="
echo "Model: $MODEL_NAME" echo "Model: $MODEL_NAME"
echo "Prompt Template: $PROMPT_TEMPLATE"
echo "==================================================" echo "=================================================="
...@@ -77,7 +54,7 @@ python -m dynamo.frontend & ...@@ -77,7 +54,7 @@ python -m dynamo.frontend &
# Start processor # Start processor
echo "Starting processor..." echo "Starting processor..."
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" & python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
EXTRA_ARGS="" EXTRA_ARGS=""
......
...@@ -49,7 +49,7 @@ if [[ $HEAD_NODE -eq 1 ]]; then ...@@ -49,7 +49,7 @@ if [[ $HEAD_NODE -eq 1 ]]; then
python -m dynamo.frontend & python -m dynamo.frontend &
# run processor # run processor
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" & python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
# Llama 4 doesn't support image embedding input, so the prefill worker will also # Llama 4 doesn't support image embedding input, so the prefill worker will also
# handle image encoding inline. # handle image encoding inline.
......
...@@ -296,7 +296,9 @@ vllm_configs = { ...@@ -296,7 +296,9 @@ vllm_configs = {
}, },
], ],
repeat_count=1, repeat_count=1,
expected_response=["purple"], # With proper prompt templating, the model actually only returns "green",
# verified behavior with native vLLM.
expected_response=["green"],
temperature=0.0, temperature=0.0,
max_tokens=100, max_tokens=100,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment