Unverified commit 52090e2e, authored by Ayush Agarwal, committed by GitHub
Browse files

feat: refactor vllm multimodal example (#3634)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 7d78fdad
......@@ -55,18 +55,18 @@ fi
# Run ingress: the HTTP frontend, listening on port 8000, started in the background.
python -m dynamo.frontend --http-port=8000 &
# Run processor (standalone component script), passing the prompt template through.
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# To make Qwen2.5-VL fit in A100 40GB, set the following extra arguments.
# EXTRA_ARGS stays empty for every other model.
EXTRA_ARGS=""
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
fi
# Run processor via the dynamo.vllm module entry point (--multimodal-processor role).
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
# Run E/P/D workers. CUDA_VISIBLE_DEVICES pins the encode worker to GPU 0 and the
# LLM worker to GPU 1. $EXTRA_ARGS is expanded unquoted so that it word-splits
# into separate command-line flags (and vanishes entirely when empty).
CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME &
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill $EXTRA_ARGS &
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME &
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --model $MODEL_NAME $EXTRA_ARGS &
# Wait for all background processes to complete
wait
......@@ -65,6 +65,11 @@ class Config:
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None
# multimodal options
multimodal_processor: bool = False
multimodal_encode_worker: bool = False
multimodal_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
# dump config to file
dump_config_to: Optional[str] = None
......@@ -137,6 +142,34 @@ def parse_args() -> Config:
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
parser.add_argument(
"--multimodal-processor",
action="store_true",
help="Run as multimodal processor component for handling multimodal requests",
)
parser.add_argument(
"--multimodal-encode-worker",
action="store_true",
help="Run as multimodal encode worker component for processing images/videos",
)
parser.add_argument(
"--multimodal-worker",
action="store_true",
help="Run as multimodal worker component for LLM inference with multimodal data",
)
parser.add_argument(
"--mm-prompt-template",
type=str,
default="USER: <image>\n<prompt> ASSISTANT:",
help=(
"Different multi-modal models expect the prompt to contain different special media prompts. "
"The processor will use this argument to construct the final prompt. "
"User prompt will replace '<prompt>' in the provided template. "
"For example, if the user prompt is 'please describe the image' and the prompt template is "
"'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
"'USER: <image> please describe the image ASSISTANT:'."
),
)
add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
......@@ -161,8 +194,35 @@ def parse_args() -> Config:
config.served_model_name = None
config.namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
config.component = "prefill" if args.is_prefill_worker else "backend"
# Check multimodal role exclusivity
mm_flags = (
int(bool(args.multimodal_processor))
+ int(bool(args.multimodal_encode_worker))
+ int(bool(args.multimodal_worker))
)
if mm_flags > 1:
raise ValueError(
"Use only one of --multimodal-processor, --multimodal-encode-worker, or --multimodal-worker"
)
# Set component and endpoint based on worker type
if args.multimodal_processor:
config.component = "processor"
config.endpoint = "generate"
elif args.multimodal_encode_worker:
config.component = "encoder"
config.endpoint = "generate"
elif args.multimodal_worker and args.is_prefill_worker:
config.component = "prefill"
config.endpoint = "generate"
elif args.is_prefill_worker:
config.component = "prefill"
config.endpoint = "generate"
else:
config.component = "backend"
config.endpoint = "generate"
config.engine_args = engine_args
config.is_prefill_worker = args.is_prefill_worker
config.is_decode_worker = args.is_decode_worker
......@@ -173,6 +233,10 @@ def parse_args() -> Config:
config.tool_call_parser = args.dyn_tool_call_parser
config.reasoning_parser = args.dyn_reasoning_parser
config.custom_jinja_template = args.custom_jinja_template
config.multimodal_processor = args.multimodal_processor
config.multimodal_encode_worker = args.multimodal_encode_worker
config.multimodal_worker = args.multimodal_worker
config.mm_prompt_template = args.mm_prompt_template
# Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None:
......
......@@ -26,6 +26,11 @@ from dynamo.llm import (
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler,
MultimodalPDWorkerHandler,
ProcessorHandler,
)
from .args import ENABLE_LMCACHE, Config, configure_ports, overwrite_args, parse_args
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
......@@ -92,7 +97,17 @@ async def worker(runtime: DistributedRuntime):
if not os.path.exists(config.model):
config.model = config.engine_args.model = await fetch_llm(config.model)
if config.is_prefill_worker:
# Route to appropriate initialization based on config flags
if config.multimodal_processor:
await init_multimodal_processor(runtime, config)
logger.debug("init_multimodal_processor completed")
elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config)
logger.debug("init_multimodal_encode_worker completed")
elif config.multimodal_worker:
await init_multimodal_worker(runtime, config)
logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker:
await init_prefill(runtime, config)
logger.debug("init_prefill completed")
else:
......@@ -430,6 +445,147 @@ def get_engine_cache_info(engine: AsyncLLM):
raise
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
    """Bring up the multimodal processor component and serve its endpoint.

    Creates the service for ``config.component``, connects to the
    ``encoder``/``generate`` endpoint in the same namespace, waits until at
    least one encoder instance is available, registers the model, and then
    serves ``ProcessorHandler.generate`` until the endpoint stops.

    Args:
        runtime: Distributed runtime used to resolve namespaces and components.
        config: Parsed worker configuration (namespace, component, endpoint,
            model, engine args and the multimodal prompt template).

    Raises:
        Exception: Whatever ``serve_endpoint`` raises; it is logged and
            re-raised. ``handler.cleanup()`` runs in either case.
    """
    own_component = runtime.namespace(config.namespace).component(config.component)
    await own_component.create_service()
    generate_endpoint = own_component.endpoint(config.endpoint)

    # Client for the downstream encode worker.
    encoder_endpoint = (
        runtime.namespace(config.namespace).component("encoder").endpoint("generate")
    )
    encode_worker_client = await encoder_endpoint.client()

    # The prompt template comes from the parsed configuration.
    handler = ProcessorHandler(
        config.engine_args,
        encode_worker_client,
        config.mm_prompt_template,
    )

    logger.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model.
    # ModelInput.Text: custom processor is used and this type bypasses SDK processor.
    await register_llm(
        ModelInput.Text,
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve the processor endpoint...")
    try:
        serving = generate_endpoint.serve_endpoint(
            handler.generate, metrics_labels=[("model", config.model)]
        )
        await asyncio.gather(serving)
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
    """Bring up the multimodal encode worker and serve its endpoint.

    Creates the service for ``config.component``, connects to the
    ``backend``/``generate`` endpoint in the same namespace, initialises the
    handler's connector, waits for a PD worker instance, and then serves
    ``EncodeWorkerHandler.generate`` until the endpoint stops.

    Args:
        runtime: Distributed runtime used to resolve namespaces and components.
        config: Parsed worker configuration.

    Raises:
        Exception: Whatever ``serve_endpoint`` raises; it is logged and
            re-raised. ``handler.cleanup()`` runs in either case.
    """
    own_component = runtime.namespace(config.namespace).component(config.component)
    await own_component.create_service()
    generate_endpoint = own_component.endpoint(config.endpoint)

    # Get PD worker client.
    # In multimodal mode, the PD worker always registers as "backend"
    # (even in disaggregated mode with prefill/decode split, we still connect to "backend").
    backend_endpoint = (
        runtime.namespace(config.namespace).component("backend").endpoint("generate")
    )
    pd_worker_client = await backend_endpoint.client()

    handler = EncodeWorkerHandler(config.engine_args, pd_worker_client)
    await handler.async_init(runtime)

    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()

    logger.info("Starting to serve the encode worker endpoint...")
    try:
        serving = generate_endpoint.serve_endpoint(
            handler.generate, metrics_labels=[("model", config.model)]
        )
        await asyncio.gather(serving)
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
    """Bring up the multimodal LLM worker and serve its endpoints.

    Sets up the vLLM engine, builds a ``MultimodalPDWorkerHandler``, attaches
    a KV event publisher when ``setup_kv_event_publisher`` returns one, and
    serves both the configured generate endpoint and ``clear_kv_blocks``.

    Args:
        runtime: Distributed runtime used to resolve namespaces and components.
        config: Parsed worker configuration.

    Raises:
        Exception: Whatever either ``serve_endpoint`` call raises; it is
            logged and re-raised. ``handler.cleanup()`` runs in either case.
    """
    own_component = runtime.namespace(config.namespace).component(config.component)
    await own_component.create_service()

    generate_endpoint = own_component.endpoint(config.endpoint)
    clear_endpoint = own_component.endpoint("clear_kv_blocks")

    # The third element (default sampling params) is not needed here.
    engine_client, vllm_config, _ = setup_vllm_engine(config)

    # TODO: Support Disaggregated mode separately
    backend_endpoint = (
        runtime.namespace(config.namespace).component("backend").endpoint("generate")
    )
    client = await backend_endpoint.client()

    handler = MultimodalPDWorkerHandler(
        runtime, own_component, engine_client, config, client
    )
    await handler.async_init(runtime)

    # Set up KV event publisher for prefix caching if enabled.
    kv_publisher = setup_kv_event_publisher(
        config, own_component, generate_endpoint, vllm_config
    )
    if kv_publisher:
        handler.kv_publisher = kv_publisher

    metrics_labels = [("model", config.model)]

    try:
        serve_generate = generate_endpoint.serve_endpoint(
            handler.generate, metrics_labels=metrics_labels
        )
        serve_clear = clear_endpoint.serve_endpoint(
            handler.clear_kv_blocks, metrics_labels=metrics_labels
        )
        await asyncio.gather(serve_generate, serve_clear)
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
def main():
    """Entry point: run ``worker`` on a uvloop event loop.

    ``worker`` is called without arguments here; presumably a decorator
    supplies its ``runtime`` parameter -- confirm against its definition.
    """
    entry = worker()
    uvloop.run(entry)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.vllm.multimodal_handlers.encode_worker_handler import EncodeWorkerHandler
from dynamo.vllm.multimodal_handlers.processor_handler import ProcessorHandler
from dynamo.vllm.multimodal_handlers.worker_handler import (
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
)
__all__ = [
"EncodeWorkerHandler",
"ProcessorHandler",
"MultimodalPDWorkerHandler",
"MultimodalDecodeWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import AsyncIterator
from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
import dynamo.nixl_connect as connect
from dynamo.runtime import Client, DistributedRuntime
from ..multimodal_utils import (
ImageLoader,
MyRequestOutput,
encode_image_embeddings,
get_encoder_components,
load_vision_model,
vLLMMultimodalRequest,
)
logger = logging.getLogger(__name__)

# Pick the array backend once at import time: cupy when it imports AND reports
# a usable CUDA device, numpy otherwise. DEVICE records which one was chosen.
try:
    import cupy as array_module

    if not array_module.cuda.is_available():
        # Reuse the ImportError path so both failure modes share one fallback.
        raise ImportError("CUDA is not available.")
except ImportError as e:
    logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
    import numpy as array_module

    DEVICE = "cpu"
else:
    DEVICE = "cuda"
    logger.info("Using cupy for array operations (GPU mode).")

# Passed as ``cache_size`` to the ImageLoader.
CACHE_SIZE_MAXIMUM = 8
class EncodeWorkerHandler:
    """Encode worker: turns an image URL into embeddings and forwards the request.

    For each request the handler loads the image, runs the image processor and
    the model-specific vision encoder/projector, exposes the resulting
    embeddings through a readable connector descriptor, and relays the request
    to a PD worker. The PD worker's streamed responses are re-emitted as JSON.
    """

    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        pd_worker_client: Client,
    ) -> None:
        """Load the image processor and vision model for ``engine_args.model``.

        Args:
            engine_args: Engine arguments; only ``model`` is read here.
            pd_worker_client: Client used to forward encoded requests.
        """
        self.pd_worker_client = pd_worker_client
        self.engine_args = engine_args
        self.model = self.engine_args.model

        self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
        self.image_processor = AutoImageProcessor.from_pretrained(
            self.model, trust_remote_code=True
        )
        self.vision_model = load_vision_model(self.model)
        self.min_workers = 1

        # Get encoder components for the model
        self.vision_encoder, self.projector = get_encoder_components(
            self.model, self.vision_model
        )
        # Created in async_init(); generate() refuses to run before that.
        self._connector = None

    def cleanup(self):
        """Release handler resources (nothing to release at the moment)."""
        pass

    async def async_init(self, runtime: DistributedRuntime):
        """Initialize the connector for RDMA transfers."""
        logger.info("Encode worker startup started.")
        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        self._connector = connect.Connector()
        await self._connector.initialize()
        logger.info("Encode worker startup completed.")

    async def generate(
        self, request: vLLMMultimodalRequest, context
    ) -> AsyncIterator[str]:
        """Encode the request's image and stream back the PD worker's output.

        Args:
            request: A ``vLLMMultimodalRequest``, its JSON string, or a mapping
                that validates into one.
            context: Request context, forwarded to the PD worker client.

        Yields:
            JSON-serialized ``MyRequestOutput`` objects. Only ``request_id``,
            ``prompt``, ``prompt_token_ids``, ``prompt_logprobs``, ``outputs``
            and ``finished`` are copied from each PD worker response.

        Raises:
            RuntimeError: If ``async_init`` has not been awaited yet.
            ValueError: If the request carries no ``image_url``.
            Exception: Any other failure is logged and re-raised.
        """
        logger.debug("Got raw request: %s", request)
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug("Received encode request: { id: %s }.", request.request_id)

        request_id = request.request_id

        # The following steps encode the requested image and provide useful embeddings.
        # 1. Open the image from the provided URL.
        # 2. Process the image using the image processor.
        # 3. Run the image through the vision model's vision tower.
        # 4. Run the results of the vision tower through the multi-modal projector.
        # 5. Create a descriptor for the embeddings.
        # 6. Create a readable operation for the descriptor and attach its
        #    metadata to the request.
        # 7. Forward the request and wait for the remote read to complete.
        # 8. Yield the responses.

        try:
            if self._connector is None:
                # Fail with an actionable message instead of an AttributeError
                # on None further down.
                raise RuntimeError(
                    "EncodeWorkerHandler.async_init() must be awaited before generate()."
                )

            if not request.multimodal_input.image_url:
                raise ValueError("image_url is required for the encode worker.")

            image = await self.image_loader.load_image(
                request.multimodal_input.image_url
            )

            logger.debug("Processing image for request: { id: %s }", request_id)
            image_embeds = self.image_processor(images=image, return_tensors="pt")

            # Encode the image embeddings using model-specific encoder
            embeddings = encode_image_embeddings(
                model_name=self.model,
                image_embeds=image_embeds,
                vision_encoder=self.vision_encoder,
                projector=self.projector,
            )

            image_grid_thw = (
                image_embeds["image_grid_thw"].tolist()
                if "image_grid_thw" in image_embeds
                else None
            )

            # The statistics need four tensor reductions; only pay for them
            # when DEBUG logging is actually enabled.
            if logger.isEnabledFor(logging.DEBUG):
                pixel_values = image_embeds["pixel_values"]
                logger.debug(
                    "Pixel values stats: mean=%s, std=%s, min=%s, max=%s",
                    pixel_values.mean().item(),
                    pixel_values.std().item(),
                    pixel_values.min().item(),
                    pixel_values.max().item(),
                )

            # Move embeddings to CPU for NIXL transfer to avoid UCX/InfiniBand issues
            embeddings_cpu = embeddings.cpu()

            request.image_grid_thw = image_grid_thw
            request.embeddings_shape = tuple(embeddings.shape)
            descriptor = connect.Descriptor(embeddings_cpu)

            with self._connector.create_readable(descriptor) as readable:
                request.serialized_request = readable.metadata()
                # Clear the image URL as hint that the image is passed as embeddings.
                request.multimodal_input.image_url = None

                # Serialize once; the same payload is logged and sent.
                payload = request.model_dump_json()
                logger.debug("Request: %s", payload)

                # Get the response generator
                response_generator = await self.pd_worker_client.round_robin(
                    payload, context=context
                )
                await readable.wait_for_completion()

                async for response in response_generator:
                    output = MyRequestOutput.model_validate_json(response.data())
                    yield MyRequestOutput(
                        request_id=output.request_id,
                        prompt=output.prompt,
                        prompt_token_ids=output.prompt_token_ids,
                        prompt_logprobs=output.prompt_logprobs,
                        outputs=output.outputs,
                        finished=output.finished,
                    ).model_dump_json()

        except Exception as e:
            logger.error(f"Error processing request {request_id}: {e}")
            raise
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Union
from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import Client
from ..multimodal_utils import (
ChatProcessor,
CompletionsProcessor,
MultiModalInput,
MultiModalRequest,
MyRequestOutput,
ProcessMixIn,
vLLMMultimodalRequest,
)
logger = logging.getLogger(__name__)
class RequestType(Enum):
    """Kind of OpenAI-style request being processed."""

    # Chat-completions request.
    CHAT = "chat"
    # Plain text-completions request.
    COMPLETION = "completion"
class ProcessorHandler(ProcessMixIn):
    """
    vLLM pre and post processing for multimodal requests.

    The handler receives a ``MultiModalRequest``, applies the configured prompt
    template to the user's text, tokenizes the prompt through the chat
    processor, forwards the request to an encode worker, and converts the
    streamed results back into OpenAI chat-completion chunks.
    """

    # Markers of the server-sent-event strings handled in generate().
    _SSE_DATA_PREFIX = "data: "
    _SSE_DONE_MARKER = "data: [DONE]"

    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        encode_worker_client: Client,
        prompt_template: str,
    ):
        """Create the tokenizer and the chat/completions processors.

        Args:
            engine_args: Engine arguments used to build the model config and
                to locate the tokenizer.
            encode_worker_client: Client used to send requests downstream.
            prompt_template: Template in which ``<prompt>`` is replaced by the
                user's text.
        """
        self.encode_worker_client = encode_worker_client
        self.prompt_template = prompt_template
        self.engine_args = engine_args
        self.model_config = self.engine_args.create_model_config()
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.tokenizer = self._create_tokenizer(self.engine_args)
        self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
        self.completions_processor = CompletionsProcessor(
            self.tokenizer, self.model_config
        )

    def cleanup(self):
        """Release handler resources (nothing to release at the moment)."""
        pass

    def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
        """Load the tokenizer for ``engine_args.model`` with left padding/truncation."""
        model_path = engine_args.model

        # Create the base tokenizer with VLLM's typical settings
        base_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            padding_side="left",
            truncation_side="left",
            use_fast=True,  # VLLM might use the fast tokenizer for efficiency
        )
        return base_tokenizer

    @staticmethod
    def _extract_multimodal_input(raw_request: MultiModalRequest) -> MultiModalInput:
        """Collect the image or video URL from all message content items.

        Raises:
            ValueError: If both an image and a video URL are present (in
                either order), or if neither is present.
        """
        multimodal_input = MultiModalInput()
        for message in raw_request.messages:
            for item in message.content:
                if item.type == "image_url":
                    # Symmetric with the video branch: the combination is
                    # rejected regardless of which item comes first.
                    if multimodal_input.video_url is not None:
                        raise ValueError("Cannot provide both image and video URLs")
                    multimodal_input.image_url = item.image_url.url
                elif item.type == "video_url":
                    if multimodal_input.image_url is not None:
                        raise ValueError("Cannot provide both image and video URLs")
                    multimodal_input.video_url = item.video_url.url

        if multimodal_input.image_url is None and multimodal_input.video_url is None:
            raise ValueError("Either image URL or video URL is required")
        return multimodal_input

    # Main method to parse the request and send the request to the vllm worker.
    async def _generate(
        self,
        raw_request: Union[CompletionRequest, ChatCompletionRequest],
        multimodal_input: MultiModalInput,
        request_type: RequestType,
        context,
    ):
        """Tokenize ``raw_request``, send it to the encode worker, stream results.

        Args:
            raw_request: OpenAI-style request to preprocess.
            multimodal_input: Image/video reference attached to the request.
            request_type: Selects how engine responses are post-processed.
            context: Request context, forwarded to the encode worker client.

        Yields:
            Whatever the processor's ``stream_response`` produces.
        """
        request_id = uuid.uuid4().hex
        logger.debug(f"Got raw request: {raw_request}")
        (
            request,
            conversation,
            _prompt,
            engine_prompt,
            sampling_params,
        ) = await self._parse_raw_request(raw_request)

        worker_request = vLLMMultimodalRequest(
            engine_prompt=engine_prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            multimodal_input=multimodal_input,
        )

        # model_dump_json() serializes the request to JSON string
        # This API could accept Pydantic class, but SamplingParams
        # in vLLMMultimodalRequest is not a Pydantic class and will
        # cause TypeError: unsupported type SamplingParams
        # The context is forwarded the same way the encode worker forwards it
        # to the PD worker.
        response_generator = await self.encode_worker_client.round_robin(
            worker_request.model_dump_json(), context=context
        )

        output = self._generate_responses(response_generator, request_type)

        # Stream the processed responses
        async for response in await self._stream_response(
            request, output, request_id, conversation
        ):
            yield response

    # This method is used to process the responses from the engine generator.
    async def _generate_responses(
        self,
        response_generator: AsyncIterator[RequestOutput],
        request_type: RequestType,
    ):
        """Convert serialized worker responses into vLLM ``RequestOutput`` objects.

        Raises:
            NotImplementedError: For any request type other than CHAT.
        """
        async for resp in response_generator:
            # Deserialize the response from the engine
            # Creates correct vLLM objects for each field
            output = MyRequestOutput.model_validate_json(resp.data())

            # OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
            request_output = RequestOutput(
                request_id=output.request_id,
                prompt=output.prompt,
                prompt_token_ids=output.prompt_token_ids,
                prompt_logprobs=output.prompt_logprobs,
                outputs=output.outputs,
                finished=output.finished,
                metrics=output.metrics,
            )

            if request_type == RequestType.CHAT:
                # For chat requests, yield the request_output directly.
                yield request_output
            else:
                raise NotImplementedError(
                    f"Request type {request_type} not implemented"
                )

    # The generate endpoint will be used by the frontend to handle incoming requests.
    async def generate(self, raw_request: MultiModalRequest, context):
        """Handle one multimodal chat request and stream chat-completion chunks.

        Args:
            raw_request: A ``MultiModalRequest`` or data that validates into one.
            context: Request context, forwarded downstream.

        Yields:
            Parsed JSON objects, one per streamed chunk; iteration stops at
            the ``data: [DONE]`` marker.

        Raises:
            ValueError: If the template lacks ``<prompt>``, the first message
                has no text content, or the media URLs are missing or
                conflicting.
        """
        logger.debug(f"Got raw request: {raw_request}")
        if not isinstance(raw_request, MultiModalRequest):
            # If the request is not MultiModalRequest, convert it to MultiModalRequest
            raw_request = MultiModalRequest.model_validate(raw_request)

        # Ensure the configured template includes the placeholder
        template = self.prompt_template
        if "<prompt>" not in template:
            raise ValueError("prompt_template must contain '<prompt>' placeholder")

        # Safely extract user text
        try:
            user_text = raw_request.messages[0].content[0].text
        except (IndexError, AttributeError, TypeError) as e:
            raise ValueError(f"Invalid message structure: {e}") from e
        if not isinstance(user_text, str):
            # str.replace() below would otherwise fail with an opaque TypeError.
            raise ValueError(
                "Invalid message structure: first content item has no text"
            )

        prompt = template.replace("<prompt>", user_text)

        msg = {
            "role": "user",
            "content": prompt,
        }

        # Set stream=True - the http frontend will handle aggregation of
        # streamed chunks into a single http response, or stream them
        # back as SSE responses based on the stream flag in the request.
        chat_request = ChatCompletionRequest(
            model=raw_request.model,
            messages=[msg],
            stream=True,
            max_tokens=raw_request.max_tokens,
            temperature=raw_request.temperature,
            request_id=str(uuid.uuid4()),
        )

        multimodal_input = self._extract_multimodal_input(raw_request)

        async for response in self._generate(
            chat_request, multimodal_input, RequestType.CHAT, context
        ):
            logger.debug(
                f"Generated response type {type(response)}, content: {response}"
            )
            # reconstructing back the OpenAI chat response as dynamo egress expects it
            if response.startswith(self._SSE_DONE_MARKER):
                break
            # Remove the exact "data: " prefix. str.lstrip() would treat the
            # argument as a set of characters rather than a prefix.
            if response.startswith(self._SSE_DATA_PREFIX):
                response = response[len(self._SSE_DATA_PREFIX) :]
            yield json.loads(response)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import logging
import torch
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect
from dynamo.runtime import Client, Component, DistributedRuntime
from ..handlers import BaseWorkerHandler
from ..multimodal_utils import (
ImageLoader,
MyRequestOutput,
construct_mm_data,
vLLMMultimodalRequest,
)
logger = logging.getLogger(__name__)
class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
    """Decode worker for disaggregated multimodal serving"""

    def __init__(
        self,
        runtime,
        component,
        engine_client,
        config,
    ):
        """Derive default sampling params from ``config`` and init the base handler."""
        # Default sampling params come from the model config built from engine args.
        model_config = config.engine_args.create_model_config()
        default_sampling_params = model_config.get_diff_sampling_param()

        # Call BaseWorkerHandler.__init__ with proper parameters
        super().__init__(runtime, component, engine_client, default_sampling_params)

        self.config = config
        self.enable_disagg = config.is_prefill_worker

    async def async_init(self, runtime: DistributedRuntime):
        """Async initialization - connector needs async setup"""
        connector = connect.Connector()
        self._connector = connector
        await connector.initialize()
        logger.info("Multimodal Decode Worker async initialization completed.")

    async def generate(self, request: vLLMMultimodalRequest, context):
        """Run generation for a decode request and stream serialized outputs.

        Args:
            request: A ``vLLMMultimodalRequest``, its JSON string, or data
                that validates into one.
            context: Request context (not used in this method).

        Yields:
            JSON-serialized ``MyRequestOutput`` for every engine response.
        """
        logger.debug(f"Got raw request: {request}")
        if isinstance(request, vLLMMultimodalRequest):
            pass
        elif isinstance(request, str):
            request = vLLMMultimodalRequest.model_validate_json(request)
        else:
            request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")

        # No multi_modal_data is attached here: only the prompt token ids are
        # handed to the engine.
        tokens_prompt = TokensPrompt(
            prompt_token_ids=request.engine_prompt["prompt_token_ids"],
        )
        response_stream = self.engine_client.generate(
            prompt=tokens_prompt,
            sampling_params=request.sampling_params,
            request_id=request.request_id,
        )

        async for response in response_stream:
            logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
            serializable = MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
                metrics=response.metrics,
                kv_transfer_params=response.kv_transfer_params,
            )
            yield serializable.model_dump_json()
class MultimodalPDWorkerHandler(BaseWorkerHandler):
    """Prefill/Decode or Prefill-only worker for multimodal serving.

    The worker receives either embeddings (described by connector metadata in
    the request) or a direct image URL, builds the multimodal engine input,
    and runs generation. When ``enable_disagg`` is set and a decode worker
    client is available, it only performs prefill locally and relays decoding
    to the decode worker.
    """

    def __init__(
        self,
        runtime,
        component: Component,
        engine_client: AsyncLLM,
        config,
        decode_worker_client: Client = None,
    ):
        """Initialise the base handler and multimodal-specific state.

        Args:
            runtime: Distributed runtime handed to the base handler.
            component: Component this handler serves.
            engine_client: vLLM async engine client.
            config: Worker configuration (``engine_args``, ``model``,
                ``is_prefill_worker`` are read).
            decode_worker_client: Optional client for remote decode.
        """
        # Get default_sampling_params from config
        default_sampling_params = (
            config.engine_args.create_model_config().get_diff_sampling_param()
        )

        # Call BaseWorkerHandler.__init__ with proper parameters
        super().__init__(runtime, component, engine_client, default_sampling_params)

        self.config = config
        self.decode_worker_client = decode_worker_client
        self.enable_disagg = config.is_prefill_worker

        # Initialize multimodal-specific components
        logger.info("Multimodal PD Worker startup started.")

        # Models with "video" in their name use uint8 buffers, all others float16.
        if "video" in self.config.model.lower():
            self.EMBEDDINGS_DTYPE = torch.uint8
        else:
            self.EMBEDDINGS_DTYPE = torch.float16

        self.EMBEDDINGS_DEVICE = "cpu"

        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        # Note: This is synchronous initialization, async initialization happens in async_init
        self._connector = None  # Will be initialized in async_init
        self.image_loader = ImageLoader()

        logger.info("Multimodal PD Worker has been initialized")

    async def async_init(self, runtime: DistributedRuntime):
        """Async initialization for connector that requires async setup"""
        # Initialize the connector asynchronously
        self._connector = connect.Connector()
        await self._connector.initialize()
        logger.info("Multimodal PD Worker async initialization completed.")

    async def _load_multimodal_data(self, request: vLLMMultimodalRequest):
        """Build the engine's ``multi_modal_data`` for ``request``.

        With no media URL on the request, the embeddings are read through the
        connector into a freshly allocated CPU tensor. Otherwise the image is
        loaded directly from ``image_url``.

        Raises:
            RuntimeError: If embeddings must be read but ``async_init`` has
                not been awaited.
            ValueError: If only ``video_url`` is set (direct video loading is
                not implemented in this handler).
        """
        media = request.multimodal_input
        if media.image_url is None and media.video_url is None:
            if self._connector is None:
                raise RuntimeError(
                    "MultimodalPDWorkerHandler.async_init() must be awaited "
                    "before generate()."
                )

            # Process embeddings using the connector
            # Create a descriptor based on the embedding shape.
            embeddings = torch.empty(
                request.embeddings_shape,
                dtype=self.EMBEDDINGS_DTYPE,
                device=self.EMBEDDINGS_DEVICE,
            )
            descriptor = connect.Descriptor(embeddings)

            read_op = await self._connector.begin_read(
                request.serialized_request, descriptor
            )
            await read_op.wait_for_completion()

            if "video" in self.config.model.lower():
                return construct_mm_data(
                    self.config.model,
                    self.EMBEDDINGS_DTYPE,
                    video_numpy=embeddings.numpy(),
                )
            return construct_mm_data(
                self.config.model,
                self.EMBEDDINGS_DTYPE,
                image_embeds=embeddings,
                image_grid_thw=request.image_grid_thw,
            )

        if media.image_url is None:
            # Only video_url is set; report it clearly instead of calling
            # load_image(None).
            raise ValueError(
                "Direct video URLs are not supported by the PD worker; "
                "provide an image URL or pre-computed embeddings."
            )

        # Use PIL image instead of image embeddings
        return {"image": await self.image_loader.load_image(media.image_url)}

    async def generate(self, request: vLLMMultimodalRequest, context):
        """Run prefill (and local or remote decode) for a multimodal request.

        Args:
            request: A ``vLLMMultimodalRequest``, its JSON string, or data
                that validates into one.
            context: Request context (not used in this method).

        Yields:
            JSON-serialized ``MyRequestOutput`` objects, coming either from
            the local engine or from the decode worker.
        """
        logger.debug(f"Got raw request: {request}")
        # isinstance (not an exact type match), matching the sibling handlers.
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")

        multi_modal_data = await self._load_multimodal_data(request)

        # Remove the image features from the request as they are not required
        request.multimodal_input.image_url = None
        request.multimodal_input.video_url = None
        request.serialized_request = None

        pd_request = copy.deepcopy(request)
        use_remote_decode = bool(self.enable_disagg and self.decode_worker_client)

        # Do prefill and remote decode if enable_disagg is true
        if use_remote_decode:
            extra_args = pd_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = {
                "do_remote_decode": True,
            }
            pd_request.sampling_params.extra_args = extra_args
            # Prefill only: generate exactly one token locally.
            pd_request.sampling_params.max_tokens = 1
            pd_request.sampling_params.min_tokens = 1

            logger.debug("Prefill request: %s", pd_request)

        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=pd_request.sampling_params,
            request_id=pd_request.request_id,
        )

        if use_remote_decode:
            decode_request = copy.deepcopy(request)
            async for prefill_response in gen:
                # Update the prompt token id in the decode request to the one
                # in response, which has image templated filled in. So that
                # the decode worker will fetch correct amount of KV blocks.
                decode_request.engine_prompt[
                    "prompt_token_ids"
                ] = prefill_response.prompt_token_ids
                logger.debug(
                    f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
                )
                extra_args = decode_request.sampling_params.extra_args or {}
                extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
                extra_args.pop("serialized_request", None)
                decode_request.sampling_params.extra_args = extra_args
                logger.debug("Decode request: %s", decode_request)
                async for (
                    decode_response
                ) in await self.decode_worker_client.round_robin(
                    decode_request.model_dump_json()
                ):
                    output = MyRequestOutput.model_validate_json(decode_response.data())
                    yield MyRequestOutput(
                        request_id=output.request_id,
                        prompt=output.prompt,
                        prompt_token_ids=output.prompt_token_ids,
                        prompt_logprobs=output.prompt_logprobs,
                        outputs=output.outputs,
                        finished=output.finished,
                        metrics=output.metrics,
                        kv_transfer_params=output.kv_transfer_params,
                    ).model_dump_json()

        else:
            async for response in gen:
                logger.debug(
                    f"Response kv_transfer_params: {response.kv_transfer_params}"
                )
                yield MyRequestOutput(
                    request_id=response.request_id,
                    prompt=response.prompt,
                    prompt_token_ids=response.prompt_token_ids,
                    prompt_logprobs=response.prompt_logprobs,
                    outputs=response.outputs,
                    finished=response.finished,
                    metrics=response.metrics,
                    kv_transfer_params=response.kv_transfer_params,
                ).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.vllm.multimodal_utils.chat_processor import (
ChatProcessor,
CompletionsProcessor,
ProcessMixIn,
)
from dynamo.vllm.multimodal_utils.encode_utils import (
encode_image_embeddings,
get_encoder_components,
)
from dynamo.vllm.multimodal_utils.http_client import get_http_client
from dynamo.vllm.multimodal_utils.image_loader import ImageLoader
from dynamo.vllm.multimodal_utils.model import (
SupportedModels,
construct_mm_data,
load_vision_model,
)
from dynamo.vllm.multimodal_utils.protocol import (
MultiModalInput,
MultiModalRequest,
MyRequestOutput,
vLLMMultimodalRequest,
)
__all__ = [
"ChatProcessor",
"CompletionsProcessor",
"ProcessMixIn",
"encode_image_embeddings",
"get_encoder_components",
"get_http_client",
"ImageLoader",
"SupportedModels",
"construct_mm_data",
"load_vision_model",
"MultiModalInput",
"MultiModalRequest",
"MyRequestOutput",
"vLLMMultimodalRequest",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
    """Structural interface: attributes a class using ``ProcessMixIn`` declares."""

    # Engine arguments the host was created with.
    engine_args: AsyncEngineArgs
    # Processor for chat requests, or None.
    chat_processor: "ChatProcessor | None"
    # Processor for completion requests, or None.
    completions_processor: "CompletionsProcessor | None"
    # Model configuration.
    model_config: ModelConfig
    # Default sampling parameters.
    default_sampling_params: SamplingParams
class ProcessMixIn(ProcessMixInRequired):
    """
    Mixin for pre and post processing for vLLM

    Requests are dispatched to ``chat_processor`` when they are
    ``ChatCompletionRequest`` instances and to ``completions_processor``
    otherwise. The host class is responsible for populating the attributes
    declared below.
    """

    engine_args: AsyncEngineArgs
    chat_processor: "ChatProcessor | None"
    completions_processor: "CompletionsProcessor | None"
    model_config: ModelConfig
    default_sampling_params: SamplingParams

    def __init__(self):
        """Do nothing; the mixin performs no initialization of its own."""

    def _get_processor(
        self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
    ):
        """Return the processor matching the request type (may be ``None``)."""
        if isinstance(raw_request, ChatCompletionRequest):
            return self.chat_processor
        return self.completions_processor

    async def _parse_raw_request(
        self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
    ):
        """
        Parse and preprocess a raw request.

        Returns:
            Tuple of (parsed request, conversation, request prompt,
            engine prompt, sampling params).

        Raises:
            RuntimeError: If the processor for this request type is ``None``.
        """
        handler = self._get_processor(raw_request)
        if handler is None:
            raise RuntimeError("Processor has not been initialized")

        parsed = handler.parse_raw_request(raw_request)
        prepared = await handler.preprocess(raw_request)

        # Whatever context length the prompt leaves free becomes the default
        # generation budget.
        context_limit = self.model_config.max_model_len
        prompt_length = len(prepared.engine_prompt["prompt_token_ids"])
        token_budget = context_limit - prompt_length

        sampling_params = parsed.to_sampling_params(
            token_budget,
            self.model_config.logits_processor_pattern,
            self.default_sampling_params,
        )
        return (
            parsed,
            prepared.conversation,
            prepared.request_prompt,
            prepared.engine_prompt,
            sampling_params,
        )

    async def _stream_response(self, request, generator, request_id, conversation):
        """
        Hand the engine output generator to the matching processor.

        Returns the object produced by the processor's ``stream_response``
        without iterating it.

        Raises:
            RuntimeError: If the processor for this request type is ``None``.
        """
        handler = self._get_processor(request)
        if handler is None:
            raise RuntimeError("processor has not been initialized")
        return handler.stream_response(request, generator, request_id, conversation)
class PreprocessResult:
    """Bundle of the values produced by a processor's ``preprocess`` call."""

    def __init__(
        self,
        conversation: Optional[ConversationMessage],
        request_prompt: RequestPrompt,
        engine_prompt: TokensPrompt,
    ):
        # ``conversation`` is None for plain completion requests.
        self.engine_prompt = engine_prompt
        self.request_prompt = request_prompt
        self.conversation = conversation
class ChatProcessor:
    """
    Pre- and post-processing for OpenAI chat completion requests.

    Wraps vLLM's ``OpenAIServingChat`` (built here without an engine client) to
    turn a chat request into engine prompts and to turn engine outputs back
    into OpenAI-style responses.
    """

    # Prefix expected on every chunk produced by the vLLM stream generator.
    _SSE_PREFIX = "data: "

    def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.openai_serving = OpenAIServingChat(
            engine_client=None,
            model_config=model_config,
            models=None,
            request_logger=None,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
        )

    def parse_raw_request(
        self, raw_request: ChatCompletionRequest
    ) -> ChatCompletionRequest:
        """Validate ``raw_request`` into a ``ChatCompletionRequest``."""
        return ChatCompletionRequest.parse_obj(raw_request)

    async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
        """
        Apply the chat template and tokenize the request.

        Returns:
            PreprocessResult holding the first conversation, request prompt and
            engine prompt returned by vLLM's ``_preprocess_chat``.
        """
        request = self.parse_raw_request(raw_request)

        # TODO: Revisit this later when adding multi-modal support for the frontend.
        # If no chat template is provided and tokenizer doesn't have one,
        # use a simple format that just concatenates messages
        if not request.chat_template and not self.tokenizer.chat_template:
            chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}Assistant:"
        else:
            chat_template = request.chat_template or self.tokenizer.chat_template

        (
            conversation,
            request_prompts,
            engine_prompts,
        ) = await self.openai_serving._preprocess_chat(
            request,
            self.tokenizer,
            request.messages,
            chat_template=chat_template,
            chat_template_content_format=self.openai_serving.chat_template_content_format,
            add_generation_prompt=request.add_generation_prompt,
            continue_final_message=request.continue_final_message,
            tool_dicts=None,
            documents=request.documents,
            chat_template_kwargs=request.chat_template_kwargs,
            tool_parser=self.openai_serving.tool_parser,
            add_special_tokens=request.add_special_tokens,
        )

        return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])

    @classmethod
    def _decode_chunk(cls, raw_response: str) -> dict:
        """Parse the JSON payload of one ``data: {...}`` chunk."""
        # removeprefix drops exactly the literal prefix; str.lstrip would treat
        # its argument as a set of characters to strip.
        return json.loads(raw_response.removeprefix(cls._SSE_PREFIX))

    @staticmethod
    def _delta_content(response: dict) -> str:
        """Return the first choice's delta content, or "" when there is none."""
        choices = response.get("choices")
        if not choices:
            return ""
        delta = choices[0].get("delta")
        if not delta:
            return ""
        return delta.get("content") or ""

    async def stream_response(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator,
        request_id: str,
        conversation: List,
    ):
        """
        Convert engine outputs into OpenAI chat responses.

        When ``request.stream`` is truthy, yields ``data: ...`` chunks (ending
        with the ``[DONE]`` chunk). Otherwise yields at most one JSON string
        containing the aggregated ``chat.completion`` object.

        Each delta is treated as carrying the full text generated so far, so
        only the part beyond what was already emitted is forwarded.
        """
        request_metadata = RequestResponseMetadata(request_id=request_id)
        chunks = self.openai_serving.chat_completion_stream_generator(
            request,
            result_generator,
            request_id,
            request.model,
            conversation,
            self.tokenizer,
            request_metadata,
            enable_force_include_usage=False,
        )
        # Length of the text already forwarded to the caller.
        num_output_text_so_far = 0

        if request.stream:
            # Handle streaming response
            async for raw_response in chunks:
                if raw_response.startswith("data: [DONE]"):
                    yield raw_response
                    break

                response = self._decode_chunk(raw_response)

                content = self._delta_content(response)
                if content:
                    # Extract only the new part from the full content
                    response["choices"][0]["delta"]["content"] = content[
                        num_output_text_so_far:
                    ]
                    num_output_text_so_far = len(content)

                # Yield the processed response
                yield f"data: {json.dumps(response)}\n\n"
        else:
            # Handle non-streaming response
            # Collect all chunks into a single response
            full_response = None
            async for raw_response in chunks:
                if raw_response.startswith("data: [DONE]"):
                    break

                response = self._decode_chunk(raw_response)
                if full_response is None:
                    # Initialize the full response structure
                    full_response = {
                        "id": response.get("id", ""),
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": request.model,
                        "choices": [
                            {
                                "index": response.get("index", 0),
                                "message": {"role": "assistant", "content": ""},
                                "finish_reason": None,
                            }
                        ],
                    }

                # Concatenate content if it exists. Each delta contains the full text so far.
                content = self._delta_content(response)
                if content:
                    full_response["choices"][0]["message"]["content"] += content[
                        num_output_text_so_far:
                    ]
                    num_output_text_so_far = len(content)

                # Update finish reason if present
                choices = response.get("choices")
                if choices and "finish_reason" in choices[0]:
                    full_response["choices"][0]["finish_reason"] = choices[0][
                        "finish_reason"
                    ]

            if full_response is not None:
                yield json.dumps(full_response)
class CompletionsProcessor:
    """
    Pre- and post-processing for OpenAI text completion requests.

    Wraps vLLM's ``OpenAIServingCompletion`` (built here without an engine
    client).
    """

    def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.openai_serving = OpenAIServingCompletion(
            engine_client=None,
            model_config=model_config,
            models=None,
            request_logger=None,
        )

    def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
        """Validate ``raw_request`` into a ``CompletionRequest``."""
        return CompletionRequest.parse_obj(raw_request)

    async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
        """
        Tokenize the request prompt.

        Returns:
            PreprocessResult with no conversation and the first request/engine
            prompt returned by vLLM's ``_preprocess_completion``.
        """
        request = self.parse_raw_request(raw_request)

        (
            request_prompts,
            engine_prompts,
        ) = await self.openai_serving._preprocess_completion(
            request,
            self.tokenizer,
            input_or_inputs=request.prompt,
            add_special_tokens=request.add_special_tokens,
        )

        return PreprocessResult(None, request_prompts[0], engine_prompts[0])

    async def stream_response(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator,
        request_id: str,
        conversation: Optional[List[ConversationMessage]] = None,
    ):
        """
        Yield each completion chunk as a parsed dict; stop at the ``[DONE]`` chunk.

        ``conversation`` is accepted for signature parity with the chat
        processor and is not used.

        Raises:
            ValueError: On first iteration, if ``request.stream`` is falsy.
        """
        request_metadata = RequestResponseMetadata(request_id=request_id)
        if not request.stream:
            raise ValueError("Only streaming responses are supported")

        async for raw_response in self.openai_serving.completion_stream_generator(
            request,
            result_generator,
            request_id,
            int(time.time()),  # created_time
            request.model,
            1,  # num_prompts
            self.tokenizer,
            request_metadata,
        ):
            if raw_response.startswith("data: [DONE]"):
                break
            # removeprefix drops exactly the literal prefix; str.lstrip would
            # treat its argument as a set of characters to strip.
            yield json.loads(raw_response.removeprefix("data: "))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional
import torch
from .model import SupportedModels, is_model_supported
logger = logging.getLogger(__name__)
def get_qwen_image_features(
    vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
    """
    Extract image features using Qwen-style vision encoder.

    Args:
        vision_encoder: The vision encoder model; must expose ``device`` and
            ``get_image_features``
        image_embeds: Dictionary containing ``pixel_values`` and
            ``image_grid_thw``

    Returns:
        Whatever ``vision_encoder.get_image_features(pixel_values, grid_thw)``
        returns

    Raises:
        KeyError: If ``pixel_values`` is missing
        ValueError: If grid_thw is not provided for Qwen model
    """
    pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)

    grid_thw = image_embeds.get("image_grid_thw", None)
    if grid_thw is None:
        raise ValueError("grid_thw is not provided")

    grid_thw = grid_thw.to(vision_encoder.device)
    logger.debug("Qwen grid_thw shape: %s", grid_thw.shape)

    # grid_thw is guaranteed to be set here, so there is a single call form.
    return vision_encoder.get_image_features(pixel_values, grid_thw)  # type: ignore
def encode_image_embeddings(
    model_name: str,
    image_embeds: Dict[str, Any],
    vision_encoder: torch.nn.Module,
    projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
    """
    Encode image embeddings using the appropriate model-specific encoder.

    Args:
        model_name: The model identifier
        image_embeds: Dictionary containing processed image data
        vision_encoder: The vision encoder module
        projector: The multimodal projector (required for LLaVA-style models)

    Returns:
        Encoded embeddings tensor; a 2-D result gets a leading dimension of
        size 1 added, other ranks are returned unchanged

    Raises:
        ValueError: If projector is missing for LLaVA models
        NotImplementedError: If model is not supported
    """
    with torch.no_grad():
        # Route through the correct encoder based on model
        if is_model_supported(model_name, SupportedModels.LLAVA_1_5_7B):
            # Fail before running the vision tower: without a projector its
            # output could not be used anyway.
            if projector is None:
                raise ValueError(f"Projector not found for LLaVA model: {model_name}")
            pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
            vision_outputs = vision_encoder(pixel_values)
            embeddings = projector(vision_outputs.last_hidden_state)
        elif is_model_supported(model_name, SupportedModels.QWEN_2_5_VL_7B):
            embeddings = get_qwen_image_features(vision_encoder, image_embeds)
        else:
            raise NotImplementedError(f"Model not supported: {model_name}")

        # Normalize output shape
        if isinstance(embeddings, (tuple, list)):
            embeddings = embeddings[0]
        if embeddings.ndim == 2:
            embeddings = embeddings.unsqueeze(0)

        return embeddings
def get_encoder_components(
    model_name: str, vision_model: torch.nn.Module
) -> tuple[Any, Optional[Any]]:
    """
    Pick the vision encoder and projector out of a loaded vision model.

    Args:
        model_name: The model identifier
        vision_model: The loaded vision model

    Returns:
        ``(vision_encoder, projector)``. For LLaVA 1.5 this is the model's
        ``vision_tower`` and its ``multi_modal_projector`` (``None`` if the
        attribute is absent); for Qwen2.5-VL it is the model itself and ``None``.

    Raises:
        NotImplementedError: If model is not supported
    """
    if is_model_supported(model_name, SupportedModels.LLAVA_1_5_7B):
        return (
            vision_model.vision_tower,
            getattr(vision_model, "multi_modal_projector", None),
        )

    if is_model_supported(model_name, SupportedModels.QWEN_2_5_VL_7B):
        return vision_model, None

    raise NotImplementedError(f"Model not supported: {model_name}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
# Global HTTP client instance
_global_http_client: Optional[httpx.AsyncClient] = None


def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
    """
    Return the shared HTTP client, creating it when missing or closed.

    Args:
        timeout: Timeout for HTTP requests. Only used when a new client is
            created; an existing open client is returned as-is.

    Returns:
        Shared HTTP client instance
    """
    global _global_http_client

    existing = _global_http_client
    if existing is not None and not existing.is_closed:
        return existing

    pool_limits = httpx.Limits(max_keepalive_connections=20, max_connections=100)
    _global_http_client = httpx.AsyncClient(
        timeout=timeout,
        follow_redirects=True,
        limits=pool_limits,
    )
    logger.info(f"Shared HTTP client initialized with timeout={timeout}s")
    return _global_http_client
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from urllib.parse import urlparse
import httpx
from PIL import Image
from .http_client import get_http_client
logger = logging.getLogger(__name__)
class ImageLoader:
    """
    Load images from http(s) URLs or base64 ``data:`` URLs as RGB PIL images.

    Images fetched over http(s) are kept in a small first-in-first-out cache
    keyed by the exact URL string.
    """

    CACHE_SIZE_MAXIMUM = 8

    def __init__(
        self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
    ):
        self._http_timeout = http_timeout
        self._image_cache: dict[str, Image.Image] = {}
        # Holds cached URLs in insertion order; its maxsize bounds the cache.
        self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)

    def _cache_image(self, image_url: str, image: Image.Image) -> None:
        """Store ``image`` under ``image_url``, evicting the oldest entry if full."""
        if image_url in self._image_cache:
            # A concurrent load of the same URL already cached it. Enqueueing
            # the key again would make a later eviction delete a missing entry.
            self._image_cache[image_url] = image
            return

        if self._cache_queue.full():
            oldest_image_url = self._cache_queue.get_nowait()
            self._image_cache.pop(oldest_image_url, None)

        self._image_cache[image_url] = image
        self._cache_queue.put_nowait(image_url)

    async def load_image(self, image_url: str) -> Image.Image:
        """
        Load the image behind ``image_url`` and convert it to RGB.

        Args:
            image_url: An ``http``/``https`` URL, or a ``data:image/...;base64,``
                URL.

        Returns:
            The image converted to RGB.

        Raises:
            httpx.HTTPError: If the HTTP request fails.
            ValueError: For any other failure (unsupported scheme or format,
                malformed data URL, empty response, undecodable image).
        """
        parsed_url = urlparse(image_url)
        is_http_url = parsed_url.scheme in ("http", "https")

        # For HTTP(S) URLs, check cache first. The key is the URL exactly as
        # given: paths and queries are case-sensitive, so lower-casing the
        # whole URL could return the image of a different resource.
        if is_http_url:
            cached_image = self._image_cache.get(image_url)
            if cached_image is not None:
                logger.debug(f"Image found in cache for URL: {image_url}")
                return cached_image

        try:
            if parsed_url.scheme == "data":
                # Parse data URL format: data:[<media type>][;base64],<data>
                if not parsed_url.path.startswith("image/"):
                    raise ValueError("Data URL must be an image type")

                # Split the path into media type and data
                media_type, data = parsed_url.path.split(",", 1)
                if ";base64" not in media_type:
                    raise ValueError("Data URL must be base64 encoded")

                try:
                    image_bytes = base64.b64decode(data)
                    image_data = BytesIO(image_bytes)
                except binascii.Error as e:
                    raise ValueError(f"Invalid base64 encoding: {e}") from e
            elif is_http_url:
                http_client = get_http_client(self._http_timeout)

                response = await http_client.get(image_url)
                response.raise_for_status()

                if not response.content:
                    raise ValueError("Empty response content from image URL")

                image_data = BytesIO(response.content)
            else:
                raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")

            # PIL is sync, so offload to a thread to avoid blocking the event loop
            image = await asyncio.to_thread(Image.open, image_data)

            # Validate image format and convert to RGB
            if image.format not in ("JPEG", "PNG", "WEBP"):
                raise ValueError(f"Unsupported image format: {image.format}")

            image_converted = image.convert("RGB")

            # Cache HTTP(S) URLs
            if is_http_url:
                self._cache_image(image_url, image_converted)

            return image_converted

        except httpx.HTTPError as e:
            logger.error(f"HTTP error loading image: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading image: {e}")
            raise ValueError(f"Failed to load image: {e}") from e
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from transformers import AutoModel
logger = logging.getLogger(__name__)
class SupportedModels:
    """Supported multimodal model identifiers"""

    # Identifiers are written in "organization/model-name" form and are
    # compared through is_model_supported(), which normalizes and lower-cases
    # both sides.
    LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
    QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
    # Treated as a video model by construct_mm_data().
    LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
def normalize_model_name(model_name: str) -> str:
    """
    Extract and normalize model name from various formats including HuggingFace cache paths.

    Args:
        model_name: Model identifier which can be:
            - A simple model name: "Qwen/Qwen2.5-VL-7B-Instruct"
            - A HuggingFace cache path: "/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/..."
            - A local path to a model directory

    Returns:
        Normalized model name in the format "organization/model-name" when one
        can be derived; the last path component for an existing local
        directory; otherwise ``model_name`` unchanged.

    Examples:
        >>> normalize_model_name("Qwen/Qwen2.5-VL-7B-Instruct")
        'Qwen/Qwen2.5-VL-7B-Instruct'
        >>> normalize_model_name("/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/snapshots/...")
        'Qwen/Qwen2.5-VL-7B-Instruct'
    """
    # If it's already a simple model name (org/model format), return as-is
    if "/" in model_name and not model_name.startswith("/"):
        return model_name

    # Handle HuggingFace cache paths: .../models--ORG--MODEL-NAME/...
    _, marker, after_marker = model_name.partition("models--")
    if marker:
        # Look only at the repository directory itself, so a "--" in a later
        # path component (snapshot or file name) cannot leak into the result.
        repo_dir = after_marker.split("/", 1)[0]
        segments = repo_dir.split("--")
        if len(segments) >= 2:
            # Last segment is the model name; everything before it is the org
            # (rejoined with the separator it was split on).
            org = "--".join(segments[:-1])
            return f"{org}/{segments[-1]}"

    # Handle local directory paths - extract the last directory name
    path = Path(model_name)
    if path.exists() and path.is_dir():
        return path.name

    # If no pattern matches, return the original name
    return model_name
def is_model_supported(model_name: str, supported_model: str) -> bool:
    """
    Tell whether ``model_name`` refers to ``supported_model``.

    Both names go through ``normalize_model_name`` and are lower-cased before
    being compared for equality.

    Args:
        model_name: The model name to check (may be path, cache name, etc.)
        supported_model: The supported model identifier

    Returns:
        True if the model is supported, False otherwise
    """
    candidate, reference = (
        normalize_model_name(name).lower() for name in (model_name, supported_model)
    )
    return candidate == reference
def load_vision_model(model_id: str) -> torch.nn.Module:
    """
    Load a vision model from a HuggingFace model ID.

    The model is loaded through ``AutoModel.from_pretrained`` in float16 with
    ``device_map="auto"``.

    NOTE(review): ``trust_remote_code=True`` is passed for every ``model_id``;
    confirm that is intended for all callers.
    """
    load_options = {
        "device_map": "auto",
        "torch_dtype": torch.float16,
        "trust_remote_code": True,
    }
    return AutoModel.from_pretrained(model_id, **load_options)
def construct_mm_data(
    model: str,
    embeddings_dtype: torch.dtype,
    image_embeds: Optional[torch.Tensor] = None,
    video_numpy: Optional[Any] = None,
    image_grid_thw: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    """
    Construct multimodal data for a vLLM request for models that require
    additional parameters alongside the embeddings.

    Returns:
        ``{"video": video_numpy}`` for the LLaVA-NeXT video model, the Qwen
        image payload for Qwen2.5-VL, and ``{"image": embeddings}`` otherwise.

    Raises:
        ValueError: If the input required by the selected model is missing.
    """
    # Video models only need the raw frames.
    if is_model_supported(model, SupportedModels.LLAVA_NEXT_VIDEO_7B):
        if video_numpy is None:
            raise ValueError("No video frames provided.")
        return {"video": video_numpy}

    # Everything below is an image model and needs embeddings.
    if image_embeds is None:
        raise ValueError("No image embeddings provided.")
    typed_embeds = image_embeds.to(embeddings_dtype)

    if not is_model_supported(model, SupportedModels.QWEN_2_5_VL_7B):
        # Default image handling for other models (e.g., LLAVA_1_5_7B)
        return {"image": typed_embeds}

    return _construct_qwen_image_data(typed_embeds, image_grid_thw)
def _construct_qwen_image_data(
image_embeds: torch.Tensor, image_grid_thw: Optional[List[Any]]
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Construct image data specifically for Qwen models."""
if image_grid_thw is None or len(image_grid_thw) == 0:
raise ValueError("No image grid provided for Qwen model.")
grid_thw_tensor = torch.tensor(image_grid_thw)
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": grid_thw_tensor,
}
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.logprobs import PromptLogprobs
from vllm.multimodal.inputs import MultiModalUUIDDict # noqa: F401
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
import dynamo.nixl_connect as connect
# Text request: a prompt string plus sampling parameters as a plain dict.
class Request(BaseModel):
    prompt: str
    sampling_params: dict
# Wrapper around a list of integer token ids.
class Tokens(BaseModel):
    tokens: list[int]
# A Request that additionally carries a request id.
class PrefillRequest(Request):
    request_id: str
# Response carrying a single text string.
class Response(BaseModel):
    text: str
# Response carrying a single boolean ``prefilled`` flag.
class PrefillResponse(BaseModel):
    prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt,
# as pydantic doesn't understand generic types.
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
    # Re-declared as an optional, untyped field so pydantic accepts any value.
    multi_modal_data: NotRequired[Optional[Any]]  # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it.
# SamplingParams is a msgspec struct, not a pydantic model.
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
    # any_schema() makes pydantic accept the value without validating it.
    lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
    """
    Serializable class of all the fields vLLM engine requires for inference
    """

    # SamplingParams is not a pydantic type: allow it as a field and encode it
    # to JSON through msgspec.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        json_encoders={SamplingParams: lambda v: json.loads(msgspec.json.encode(v))},
    )

    engine_prompt: PatchedTokensPrompt
    sampling_params: SamplingParams
    request_id: str
    prefix_hit_rate: Optional[float] = 0.0

    @field_validator("sampling_params", mode="before")
    @classmethod
    def parse_sampling_params(cls, v: Any) -> SamplingParams:
        # Accept a JSON string or a dict of keyword arguments; anything else
        # (e.g. an existing SamplingParams) is passed through unchanged.
        payload = json.loads(v) if isinstance(v, str) else v
        if isinstance(payload, dict):
            return SamplingParams(**payload)
        return payload
# Message content part of type "text".
class TextContent(BaseModel):
    type: Literal["text"]
    text: str
# Holder for the URL of an image content part.
class ImageURLDetail(BaseModel):
    url: str
# Message content part of type "image_url".
class ImageContent(BaseModel):
    type: Literal["image_url"]
    image_url: ImageURLDetail
# Holder for the URL of a video content part.
class VideoURLDetail(BaseModel):
    url: str
# Message content part of type "video_url".
class VideoContent(BaseModel):
    type: Literal["video_url"]
    video_url: VideoURLDetail
# Any one of the content part models; each has a distinct literal ``type`` value.
MessageContent = Union[TextContent, ImageContent, VideoContent]
# One chat message: a role plus a list of content parts.
class ChatMessage(BaseModel):
    role: Literal["user", "system", "assistant"]
    content: List[MessageContent]
# Chat-style request whose messages may contain text, image and video parts.
class MultiModalRequest(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    model: str
    messages: List[ChatMessage]
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    # Streaming is on unless the caller sets this to a falsy value.
    stream: Optional[bool] = True
# Optional image and/or video URL attached to a request; both default to None.
class MultiModalInput(BaseModel):
    image_url: Optional[str] = None
    video_url: Optional[str] = None
# vLLMGenerateRequest extended with multimodal fields.
class vLLMMultimodalRequest(vLLMGenerateRequest):
    # NOTE(review): this assignment replaces the parent's model_config, which
    # also sets json_encoders for SamplingParams; confirm pydantic still
    # applies the parent's encoder here.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Defaults to an empty MultiModalInput (both URLs None).
    multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
    image_grid_thw: Optional[List[Any]] = None
    # Shape as a 3- or 4-tuple of ints.
    embeddings_shape: Optional[
        Union[Tuple[int, int, int], Tuple[int, int, int, int]]
    ] = None
    # NOTE(review): presumably the RDMA descriptor used to transfer the
    # embeddings between workers; confirm against the callers.
    serialized_request: Optional[connect.RdmaMetadata] = None
class MyRequestOutput(BaseModel):
    """
    RequestOutput from vLLM is not serializable by default
    https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85

    This class is used to serialize the RequestOutput and any recursively defined types
    We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
    """

    # Needed because the vLLM types used as fields are not pydantic models.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    request_id: str
    prompt: Optional[str] = None
    prompt_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[PromptLogprobs] = None
    outputs: List[CompletionOutput]
    finished: bool
    metrics: Optional[RequestMetrics] = None
    kv_transfer_params: Optional[dict[str, Any]] = None
    # The fields below are currently disabled.
    # NOTE(review): presumably further RequestOutput fields left out on purpose; confirm.
    # lora_request: Optional[LoRARequest] = None
    # encoder_prompt: Optional[str] = None
    # encoder_prompt_token_ids: Optional[List[int]] = None
    # num_cached_tokens: Optional[int] = None
    # multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
......@@ -152,6 +152,8 @@ addopts = [
"--ignore-glob=docs/*",
"--ignore-glob=components/src/dynamo/sglang/request_handlers/*",
"--ignore-glob=components/src/dynamo/sglang/multimodal_utils/*",
"--ignore-glob=components/src/dynamo/vllm/multimodal_utils/*",
"--ignore-glob=components/src/dynamo/vllm/multimodal_handlers/*",
"--ignore-glob=components/backends/sglang/slurm_jobs/*",
# FIXME: Get relative/generic blob paths to work here
]
......
......@@ -106,8 +106,8 @@ vllm_configs = {
),
"multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava",
directory=os.path.join(WORKSPACE_DIR, "examples/multimodal"),
script_name="agg.sh",
directory=vllm_dir,
script_name="agg_multimodal.sh",
marks=[pytest.mark.gpu_2],
model="llava-hf/llava-1.5-7b-hf",
script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
......@@ -130,8 +130,8 @@ vllm_configs = {
),
"multimodal_agg_qwen": VLLMConfig(
name="multimodal_agg_qwen",
directory=os.path.join(WORKSPACE_DIR, "examples/multimodal"),
script_name="agg.sh",
directory=vllm_dir,
script_name="agg_multimodal.sh",
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen2.5-VL-7B-Instruct",
delayed_start=0,
......@@ -153,6 +153,7 @@ vllm_configs = {
)
],
),
# TODO: Update this test case when we have video multimodal support in vllm official components
"multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg",
directory=os.path.join(WORKSPACE_DIR, "examples/multimodal"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment