"lib/bindings/vscode:/vscode.git/clone" did not exist on "481bf395ce8221ee5efdd940ef0ead8b5466a6f5"
Unverified commit 8a098a66, authored by Michal Guzek, committed by GitHub
Browse files

feat: TRT-LLM multimodal preprocessor with backend media decoding (#5910)


Signed-off-by: Michal Guzek <mguzek@nvidia.com>
parent 9bdc8b73
......@@ -11,9 +11,11 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlPersistentEmbeddingSender,
TransferRequest,
)
from dynamo.common.multimodal.image_loader import ImageLoader
__all__ = [
"AsyncEncoderCache",
"ImageLoader",
"NixlPersistentEmbeddingReceiver",
"NixlPersistentEmbeddingSender",
"TransferRequest",
......
......@@ -18,15 +18,23 @@ import base64
import binascii
import logging
from io import BytesIO
from typing import Any, Dict, Final, List, Optional
from urllib.parse import urlparse
import httpx
from PIL import Image
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from .http_client import get_http_client
logger = logging.getLogger(__name__)
# Constants for multimodal data variants
URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded"
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
......@@ -108,3 +116,77 @@ class ImageLoader:
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
async def load_image_batch(
    self,
    image_mm_items: List[Dict[str, Any]],
    enable_frontend_decoding: bool = False,
    nixl_connector: Optional["nixl_connect.Connector"] = None,
) -> List[Any]:
    """
    Load a batch of images from multimodal data items.

    Supports two paths:
    1. Url variant: load the image via ``self.load_image(url)`` (default)
    2. Decoded variant: read the pre-decoded image via
       ``read_decoded_media_via_nixl`` (requires enable_frontend_decoding=True)

    Items that are neither a ``{"Url": ...}`` nor a ``{"Decoded": ...}`` dict
    are skipped with a warning.

    Args:
        image_mm_items: List of multimodal data items for images
        enable_frontend_decoding: If True, Decoded items are accepted and read
            through ``nixl_connector``
        nixl_connector: NIXL connector used for Decoded items (required when a
            Decoded item is present and enable_frontend_decoding=True)

    Returns:
        List of loaded image data, in the order of the accepted input items

    Raises:
        ValueError: If a Decoded item is received while
            enable_frontend_decoding=False, or while nixl_connector is None
        Exception: If any image fails to load; the message aggregates one line
            per failed item
    """
    # Phase 1: validate every item before creating any awaitable, so a
    # validation error never leaves already-created coroutines un-awaited.
    # Each entry is (source label for error messages, is_url, payload).
    planned = []
    for item in image_mm_items:
        if isinstance(item, dict) and URL_VARIANT_KEY in item:
            # URL path: download and decode in Python backend
            url = item[URL_VARIANT_KEY]
            planned.append((url, True, url))
            logger.debug(f"Preparing to load image from URL: {url[:80]}...")
        elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
            if not enable_frontend_decoding:
                logger.error(
                    "Received Decoded multimodal data but enable_frontend_decoding=False. "
                    "Set enable_frontend_decoding=True to enable NIXL RDMA image transfer."
                )
                raise ValueError("Could not load decoded media from frontend")
            if nixl_connector is None:
                logger.error(
                    "Frontend decoding enabled but nixl_connector not provided. "
                    "Caller must pass an initialized NIXL connector."
                )
                raise ValueError(
                    "nixl_connector required when enable_frontend_decoding=True"
                )
            planned.append(("decoded", False, item[DECODED_VARIANT_KEY]))
        else:
            # Previously dropped silently; surface it so a malformed request
            # is visible in the logs.
            logger.warning(
                f"Skipping unsupported image item of type {type(item).__name__}"
            )

    if not planned:
        return []

    # Phase 2: create the awaitables and process images in parallel
    image_futures = [
        self.load_image(payload)
        if is_url
        else read_decoded_media_via_nixl(nixl_connector, payload)
        for _, is_url, payload in planned
    ]
    results = await asyncio.gather(*image_futures, return_exceptions=True)

    loaded_images = []
    failures = []
    # Pair results with the planned entries (not the raw input list): skipped
    # items have no result, so zipping against the input would misattribute
    # failures to the wrong source.
    for (source, _, _), result in zip(planned, results):
        # BaseException also covers asyncio.CancelledError, which gather()
        # returns as a result when return_exceptions=True.
        if isinstance(result, BaseException):
            message = f"Failed to load image from {source[:80]}...: {result}"
            logger.error(message)
            failures.append(message)
            continue
        loaded_images.append(result)

    if failures:
        raise Exception("".join(f"{message}\n" for message in failures))
    return loaded_images
......@@ -3,19 +3,32 @@
import asyncio
import logging
import threading
from dataclasses import asdict
from typing import Any, Dict, Optional, Union
import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec
class EncodeHelper:
"""Utility class for encoding and serialization operations."""
# Shared ImageLoader for full EPD flow (async image loading)
_image_loader: Optional[ImageLoader] = None
_image_loader_lock = threading.Lock()
@classmethod
def _get_image_loader(cls) -> ImageLoader:
    """Return the ImageLoader shared by the class, creating it on first use."""
    # Fast path: an instance already exists, no need to take the lock.
    if cls._image_loader is not None:
        return cls._image_loader
    with cls._image_loader_lock:
        # Re-check while holding the lock: another thread may have created
        # the loader between the unlocked check and acquiring the lock.
        if cls._image_loader is None:
            cls._image_loader = ImageLoader()
        return cls._image_loader
@staticmethod
def serialize_tensor_dict(tensor_dict: dict) -> dict:
"""Serialize a dictionary of tensors to JSON-serializable format.
......@@ -269,7 +282,7 @@ class EncodeHelper:
@staticmethod
async def _process_full_epd_flow(
text_prompt: str,
prompt_token_ids_from_request: list,
image_urls: list,
tokenizer,
model_dir: str,
......@@ -283,34 +296,33 @@ class EncodeHelper:
containing multimodal embedding handles for the prefill worker.
Args:
text_prompt: Text portion of the prompt
prompt_token_ids_from_request: token IDs from the request (Rust preprocessor)
image_urls: List of image URLs to process
tokenizer: Tokenizer for encoding the processed prompt
model_dir: Path to model directory (required for AutoProcessor)
model_type: Model type string (required for placeholder retrieval)
tokenizer: Tokenizer for decoding prompt_token_ids_from_request
model_dir: Path to model directory (unused; kept for API compatibility)
model_type: Model type string (unused; kept for API compatibility)
engine: TensorRTLLMEngine with MultimodalEncoder
Yields:
Response with ep_disaggregated_params, processed_prompt, and prompt_token_ids
"""
# NOTE: `default_multimodal_input_loader` requires `model_dir` to load the
# HuggingFace AutoProcessor (for chat template application) and as a fallback
# for tokenizer loading. `model_type` is needed to retrieve the correct
# multimodal placeholders and apply model-specific preprocessing.
# NOTE: default_multimodal_input_loader downloads images and preprocesses them
# synchronously. Wrap in asyncio.to_thread to allow concurrent image loading
# across multiple requests, improving throughput at high concurrency.
inputs = await asyncio.to_thread(
lambda: default_multimodal_input_loader(
tokenizer=tokenizer,
model_dir=model_dir,
model_type=model_type,
modality="image",
prompts=[text_prompt],
media=image_urls[0],
)
)
# Load images with shared ImageLoader (async, same as multimodal_processor PD flow).
image_items = [{"Url": u} for u in image_urls]
image_loader = EncodeHelper._get_image_loader()
pil_images = await image_loader.load_image_batch(image_items)
if not pil_images:
logging.error("ENCODE WORKER: no images loaded from image_urls")
yield {"ep_disaggregated_params": None}
return
processed_mm_data = {"image": pil_images}
inputs = [
{
"prompt_token_ids": prompt_token_ids_from_request,
"multi_modal_data": processed_mm_data,
"mm_processor_kwargs": {},
}
]
# NOTE: MultimodalEncoder.generate() is synchronous. Run it off-thread to avoid
# blocking the encode worker's event loop under concurrency.
......@@ -340,25 +352,16 @@ class EncodeHelper:
encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params)
params_dict = asdict(encoded_params)
# Extract processed prompt (includes <image> tokens) for prefill/decode consistency
# Extract processed prompt (includes <image> tokens) for prefill/decode consistency.
# NOTE: processed_prompt will contain template/placeholder tokens
# (e.g. <image>, [INST], etc.). Adding special tokens here can change
# token alignment across EPD stages (prefill/decode), so we explicitly
# avoid adding them.
processed_prompt = None
prompt_token_ids = None
if isinstance(inputs, list) and len(inputs) > 0:
first_input = inputs[0]
if isinstance(first_input, dict):
processed_prompt = first_input.get("prompt")
else:
processed_prompt = getattr(first_input, "prompt", None)
# Tokenize the processed prompt for prefill worker
if processed_prompt and tokenizer is not None:
# NOTE: processed_prompt already contains template/placeholder tokens
# (e.g. <image>, [INST], etc.). Adding special tokens here can change
# token alignment across EPD stages (prefill/decode), so we explicitly
# avoid adding them.
prompt_token_ids = tokenizer.encode(
processed_prompt, add_special_tokens=False
)
if tokenizer is not None:
processed_prompt = tokenizer.decode(
prompt_token_ids_from_request, skip_special_tokens=False
)
logging.debug(
"ENCODE WORKER: Extracted processed_prompt (len=%s)",
......@@ -368,7 +371,7 @@ class EncodeHelper:
yield {
"ep_disaggregated_params": params_dict,
"processed_prompt": processed_prompt,
"prompt_token_ids": prompt_token_ids,
"prompt_token_ids": prompt_token_ids_from_request,
}
@staticmethod
......@@ -407,7 +410,7 @@ class EncodeHelper:
"messages", request.get("messages", [])
)
(
text_prompt,
_,
image_urls,
embedding_paths,
) = multimodal_processor.extract_prompt_and_media(messages)
......@@ -423,7 +426,7 @@ class EncodeHelper:
yield response
# Flow 2: Full EPD flow (image URLs via MultimodalEncoder)
elif image_urls and text_prompt:
elif image_urls and request.get("token_ids"):
if model_dir is None or model_type is None:
yield {
"error": "model_dir and model_type are required for full EPD encode"
......@@ -432,11 +435,22 @@ class EncodeHelper:
if engine is None:
yield {"error": "No engine configured on encode worker for full EPD"}
return
# Use token_ids from request (Rust preprocessor already applied
# chat template and tokenized; token_ids then include image placeholder tokens
# if the model's tokenizer_config chat template emits them).
token_ids = request.get("token_ids")
async for response in EncodeHelper._process_full_epd_flow(
text_prompt, image_urls, tokenizer, model_dir, model_type, engine
token_ids,
image_urls,
tokenizer,
model_dir,
model_type,
engine,
):
yield response
# No valid multimodal content found
else:
yield {"error": "No embedding_paths or image_urls found in request"}
yield {
"error": "No embedding_paths or image_urls found in request, or image_urls without text_prompt or token_ids"
}
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import time
from io import BytesIO
......@@ -23,9 +22,9 @@ from urllib.parse import urlparse
from urllib.request import urlopen
import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
......@@ -69,6 +68,8 @@ class MultimodalRequestProcessor:
else:
self.tokenizer = tokenizer_factory(model_dir)
self.image_loader = ImageLoader()
def is_url(self, path: str) -> bool:
"""Check if a path is a URL."""
parsed = urlparse(path)
......@@ -163,76 +164,150 @@ class MultimodalRequestProcessor:
else:
image_urls.append(url)
return " ".join(text_parts), image_urls, embedding_paths
return "".join(text_parts), image_urls, embedding_paths
async def process_openai_request(
self, request: Dict, embeddings: Any, ep_disaggregated_params: Any
) -> Optional[Any]:
"""Process OpenAI request and return with multimodal data."""
# Extract messages - check extra_args first (from Rust preprocessor for multimodal)
# Fall back to direct messages field for backward compatibility
"""
Process OpenAI request and return multimodal data in TokensPrompt format.
Supports three flows:
1. EPD Case 1: Encoder fully processed (has _epd_processed_prompt)
2. EPD Case 2: NIXL embeddings (embeddings parameter is not None)
3. PD Flow: Rust pre-tokenized with direct media loading
Returns dict compatible with TRT-LLM's generate_async:
{
"prompt_token_ids": List[int],
"multi_modal_data": Dict[str, List[torch.Tensor]]
}
or for EPD Case 1:
{
"prompt": str,
"prompt_token_ids": List[int]
}
"""
self.previous_decoded_text = ""
messages = request.get("extra_args", {}).get(
"messages", request.get("messages", [])
)
text_prompt, image_urls, embedding_paths = self.extract_prompt_and_media(
messages
)
if not image_urls and not embedding_paths and not ep_disaggregated_params:
logging.warning("No multimodal content, returning None")
return None
# EPD Flow Case 1: Encoder has fully processed the prompt
# The encode worker has done everything: vision encoding, prompt processing, tokenization
# Return the encoder's processed prompt and tokens directly
processed_prompt_from_encoder = request.get("_epd_processed_prompt")
# Only use EPD flow if we actually have encoder data
# For PD flow (no encoder), fall through to embedding_paths handling
if processed_prompt_from_encoder is not None:
text_prompt = processed_prompt_from_encoder
result = {"prompt": text_prompt}
logging.info("MM: Using fully processed prompt from encoder")
result = {"prompt": processed_prompt_from_encoder}
prompt_token_ids = request.get("_epd_prompt_token_ids")
if prompt_token_ids:
result["prompt_token_ids"] = prompt_token_ids
else:
logging.warning("MM PROCESSOR: No prompt_token_ids from encoder")
logging.warning("MM: No prompt_token_ids from encoder")
return result
loader_kwargs = {}
# Get token_ids from request (already tokenized by Rust frontend)
token_ids = request.get("token_ids")
if not token_ids:
logging.warning("No token_ids in request")
return None
# Initialize result in TokensPrompt format
# mm_processor_kwargs must be a dict (not None) for TRT-LLM's processor
processed_inputs = {"prompt_token_ids": token_ids, "mm_processor_kwargs": {}}
# EPD Flow Case 2: Embeddings received via NIXL from encode worker
# The encode worker computed vision embeddings and transferred them via RDMA/NIXL
# We need to pass these embeddings directly to TRT-LLM's generate_async
if embeddings is not None:
# EPD flow - embeddings received from encode worker via NIXL
loader_kwargs["mm_embeddings"] = [embeddings]
logging.info(
f"Using NIXL embeddings: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
)
elif image_urls:
# Image-only flow
loader_kwargs["media"] = [image_urls]
elif embedding_paths:
# PD flow with no NIXL and no encoder
loader_kwargs["mm_embeddings"] = [
self.load_tensor_from_path_or_url(path) for path in embedding_paths
]
logging.info(f"Using embedding paths: {embedding_paths}")
# NOTE: default_multimodal_input_loader downloads images and preprocesses them
# synchronously. Wrap in asyncio.to_thread to allow concurrent image loading
# across multiple requests, improving throughput at high concurrency.
processed_inputs = await asyncio.to_thread(
lambda: default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_dir,
model_type=self.model_type,
modality=self.modality,
prompts=[text_prompt],
image_data_format="pt",
device="cuda",
**loader_kwargs,
f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
)
)
# Return the first processed input if available
if processed_inputs:
return processed_inputs[0]
# Structure embeddings in the format TRT-LLM's generate_async expects
processed_inputs["multi_modal_embeddings"] = embeddings
return processed_inputs
# PD Flow: Pre-tokenized by Rust frontend with direct media loading
# TODO: Add frontend decoding support
# Handle multimodal data if present
multi_modal_data = request.get("multi_modal_data")
if multi_modal_data and isinstance(multi_modal_data, dict):
processed_mm_data = {}
# Process images and embedding paths from image_url field
image_items = multi_modal_data.get("image_url", [])
if image_items and isinstance(image_items, list):
# Separate embedding paths from regular image URLs
# Items come from Rust in format: {"Url": "..."} or {"Decoded": ...}
embedding_paths = []
image_urls = []
for item in image_items:
# Extract URL from item (Rust enum serialization uses "Url" with capital U)
if isinstance(item, dict) and "Url" in item:
url = item["Url"]
elif isinstance(item, dict) and "Decoded" in item:
# Already decoded data (NIXL) - always treat as image
image_urls.append(item)
continue
elif isinstance(item, str):
# Fallback for string URLs (backward compatibility)
url = item
else:
logging.warning(
f"Unexpected item format in image_items: {item}"
)
continue
# Check if this is an embedding file based on extension
if url.endswith((".pt", ".pth", ".bin")):
embedding_paths.append(url)
else:
# Keep original item format for load_image_batch
image_urls.append(
item if isinstance(item, dict) else {"Url": item}
)
# Load regular images as PIL Images for TRT-LLM's input processor
# TRT-LLM will auto-detect this and compute mrope_config
if image_urls:
try:
pil_images = await self.image_loader.load_image_batch(
image_urls
)
if pil_images:
processed_mm_data["image"] = pil_images
logging.info(
f"Loaded {len(pil_images)} image(s) as PIL Images"
)
except Exception as e:
logging.error(f"Failed to load images: {e}")
return None
# Load embedding files (.pt, .pth, .bin) for PD flow
# These are pre-computed vision encoder outputs
if embedding_paths:
try:
loaded_embeddings = [
self.load_tensor_from_path_or_url(path)
for path in embedding_paths
]
if loaded_embeddings:
processed_mm_data["embedding"] = loaded_embeddings
logging.info(
f"Loaded {len(loaded_embeddings)} embedding file(s) from paths: {embedding_paths}"
)
except Exception as e:
logging.error(f"Failed to load embeddings: {e}")
return None
# TODO: Add support for video_url, audio_url
if processed_mm_data:
processed_inputs["multi_modal_data"] = processed_mm_data
return None
return processed_inputs
def create_response_chunk(
self,
......
......@@ -520,7 +520,16 @@ class HandlerBase(BaseGenerativeHandler):
if processed_input:
return processed_input
# Fallback: text-only flow
# If multimodal processing returned None but request has multimodal data,
# this is an error (not a text-only request). Raise instead of falling back.
if request.get("multi_modal_data"):
raise RuntimeError(
"Failed to process multimodal request. Check server logs for details. "
"Common issues: missing allowed_local_media_path configuration, "
"file not found, or file outside allowed directory."
)
# Fallback: text-only flow (no multimodal processor or no multimodal data)
return request.get("token_ids")
def _normalize_request_format(self, request: dict) -> None:
......
......@@ -23,9 +23,9 @@ from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.llm import (
KvEventPublisher,
......@@ -39,7 +39,6 @@ from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.image_loader import ImageLoader
# Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url"
......@@ -904,68 +903,6 @@ class BaseWorkerHandler(ABC):
return prompt, sequence_length, embeddings_tensor
async def _load_image_batch(
    self, image_mm_items: list[Dict[str, Any]]
) -> list[Any]:
    """
    Load a batch of images from multimodal data items.

    Supports two paths:
    1. Url variant: load the image via ``self.image_loader.load_image(url)`` (default)
    2. Decoded variant: read the pre-decoded image via
       ``read_decoded_media_via_nixl`` (requires --frontend-decoding)

    Items that are neither a ``{"Url": ...}`` nor a ``{"Decoded": ...}`` dict
    are skipped with a warning.

    Args:
        image_mm_items: List of multimodal data items for images

    Returns:
        List of loaded image data, in the order of the accepted input items

    Raises:
        ValueError: If a Decoded item is received while frontend decoding is
            not enabled
        Exception: If any image fails to load; the message aggregates one line
            per failed item
    """
    # Phase 1: validate every item before creating any awaitable, so a
    # validation error never leaves already-created coroutines un-awaited.
    # Each entry is (source label for error messages, is_url, payload).
    planned = []
    for item in image_mm_items:
        if isinstance(item, dict) and URL_VARIANT_KEY in item:
            # URL path: download and decode in Python backend
            url = item[URL_VARIANT_KEY]
            planned.append((url, True, url))
            logger.debug(f"Preparing to load image from URL: {url[:80]}...")
        elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
            if not self.enable_frontend_decoding:
                logger.error(
                    "Received Decoded multimodal data but --frontend-decoding not enabled. "
                    "Use --frontend-decoding flag to enable NIXL RDMA image transfer."
                )
                raise ValueError("Could not load decoded media from frontend")
            planned.append(("decoded", False, item[DECODED_VARIANT_KEY]))
        else:
            # Previously dropped silently; surface it so a malformed request
            # is visible in the logs.
            logger.warning(
                f"Skipping unsupported image item of type {type(item).__name__}"
            )

    if not planned:
        return []

    # Lazily create the NIXL connector once per batch (instead of taking the
    # lock once per decoded item), and only when a decoded item needs it.
    if any(not is_url for _, is_url, _ in planned):
        async with self._nixl_connector_lock:
            if self._nixl_connector is None:
                self._nixl_connector = nixl_connect.Connector()
                await self._nixl_connector.initialize()

    # Phase 2: create the awaitables and process images in parallel
    image_futures = [
        self.image_loader.load_image(payload)
        if is_url
        else read_decoded_media_via_nixl(self._nixl_connector, payload)
        for _, is_url, payload in planned
    ]
    results = await asyncio.gather(*image_futures, return_exceptions=True)

    loaded_images = []
    failures = []
    # Pair results with the planned entries (not the raw input list): skipped
    # items have no result, so zipping against the input would misattribute
    # failures to the wrong source.
    for (source, _, _), result in zip(planned, results):
        # BaseException also covers asyncio.CancelledError, which gather()
        # returns as a result when return_exceptions=True.
        if isinstance(result, BaseException):
            message = f"Failed to load image from {source[:80]}...: {result}"
            logger.error(message)
            failures.append(message)
            continue
        loaded_images.append(result)

    if failures:
        raise Exception("".join(f"{message}\n" for message in failures))
    return loaded_images
async def _extract_multimodal_data(
self, request: Dict[str, Any]
) -> Dict[str, Any] | None:
......@@ -985,8 +922,19 @@ class BaseWorkerHandler(ABC):
mm_map = request["multi_modal_data"]
vllm_mm_data = {}
# Lazy-init NIXL connector only when frontend decoding is enabled
if self.enable_frontend_decoding:
async with self._nixl_connector_lock:
if self._nixl_connector is None:
self._nixl_connector = nixl_connect.Connector()
await self._nixl_connector.initialize()
# Process image_url entries
images = await self._load_image_batch(mm_map.get(IMAGE_URL_KEY, []))
images = await self.image_loader.load_image_batch(
mm_map.get(IMAGE_URL_KEY, []),
enable_frontend_decoding=self.enable_frontend_decoding,
nixl_connector=self._nixl_connector,
)
if images:
# vLLM expects single image or list
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.common.multimodal.http_client import get_http_client
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.chat_processor import (
ChatProcessor,
......@@ -12,8 +14,6 @@ from dynamo.vllm.multimodal_utils.encode_utils import (
get_embedding_hash,
get_encoder_components,
)
from dynamo.vllm.multimodal_utils.http_client import get_http_client
from dynamo.vllm.multimodal_utils.image_loader import ImageLoader
from dynamo.vllm.multimodal_utils.model import (
SupportedModels,
construct_mm_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment