Unverified Commit 8a098a66 authored by Michal Guzek's avatar Michal Guzek Committed by GitHub
Browse files

feat: TRT-LLM multimodal preprocessor with backend media decoding (#5910)


Signed-off-by: default avatarMichal Guzek <mguzek@nvidia.com>
parent 9bdc8b73
...@@ -11,9 +11,11 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -11,9 +11,11 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlPersistentEmbeddingSender, NixlPersistentEmbeddingSender,
TransferRequest, TransferRequest,
) )
from dynamo.common.multimodal.image_loader import ImageLoader
__all__ = [ __all__ = [
"AsyncEncoderCache", "AsyncEncoderCache",
"ImageLoader",
"NixlPersistentEmbeddingReceiver", "NixlPersistentEmbeddingReceiver",
"NixlPersistentEmbeddingSender", "NixlPersistentEmbeddingSender",
"TransferRequest", "TransferRequest",
......
...@@ -18,15 +18,23 @@ import base64 ...@@ -18,15 +18,23 @@ import base64
import binascii import binascii
import logging import logging
from io import BytesIO from io import BytesIO
from typing import Any, Dict, Final, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from PIL import Image from PIL import Image
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from .http_client import get_http_client from .http_client import get_http_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Constants for multimodal data variants
URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded"
class ImageLoader: class ImageLoader:
CACHE_SIZE_MAXIMUM = 8 CACHE_SIZE_MAXIMUM = 8
...@@ -108,3 +116,77 @@ class ImageLoader: ...@@ -108,3 +116,77 @@ class ImageLoader:
except Exception as e: except Exception as e:
logger.error(f"Error loading image: {e}") logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}") raise ValueError(f"Failed to load image: {e}")
async def load_image_batch(
self,
image_mm_items: List[Dict[str, Any]],
enable_frontend_decoding: bool = False,
nixl_connector: Optional["nixl_connect.Connector"] = None,
) -> List[Any]:
"""
Load a batch of images from multimodal data items.
Supports two paths:
1. Url variant: Download and decode image from URL (default)
2. Decoded variant: Read pre-decoded image via NIXL RDMA (requires enable_frontend_decoding=True)
Args:
image_mm_items: List of multimodal data items for images
enable_frontend_decoding: If True, enables NIXL RDMA for decoded images
nixl_connector: NIXL connector for frontend decoding (required if enable_frontend_decoding=True)
Returns:
List of loaded image data
Raises:
Exception: If any image fails to load
ValueError: If enable_frontend_decoding=True but nixl_connector is None
"""
image_futures = []
for item in image_mm_items:
if isinstance(item, dict) and URL_VARIANT_KEY in item:
# URL path: download and decode in Python backend
url = item[URL_VARIANT_KEY]
image_futures.append(self.load_image(url))
logger.debug(f"Preparing to load image from URL: {url[:80]}...")
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
if enable_frontend_decoding:
if nixl_connector is None:
logger.error(
"Frontend decoding enabled but nixl_connector not provided. "
"Caller must pass an initialized NIXL connector."
)
raise ValueError(
"nixl_connector required when enable_frontend_decoding=True"
)
metadata = item[DECODED_VARIANT_KEY]
image_futures.append(
read_decoded_media_via_nixl(nixl_connector, metadata)
)
else:
logger.error(
"Received Decoded multimodal data but enable_frontend_decoding=False. "
"Set enable_frontend_decoding=True to enable NIXL RDMA image transfer."
)
raise ValueError("Could not load decoded media from frontend")
# Process images in parallel
results = await asyncio.gather(*image_futures, return_exceptions=True)
loaded_images = []
collective_exceptions = ""
for media_item, result in zip(image_mm_items, results):
if isinstance(result, Exception):
source = media_item.get(URL_VARIANT_KEY, "decoded")
logger.error(f"Failed to load image from {source[:80]}...: {result}")
collective_exceptions += (
f"Failed to load image from {source[:80]}...: {result}\n"
)
continue
loaded_images.append(result)
if collective_exceptions:
raise Exception(collective_exceptions)
return loaded_images
...@@ -3,19 +3,32 @@ ...@@ -3,19 +3,32 @@
import asyncio import asyncio
import logging import logging
import threading
from dataclasses import asdict from dataclasses import asdict
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
import torch import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec from dynamo.trtllm.utils.disagg_utils import DisaggregatedParamsCodec
class EncodeHelper: class EncodeHelper:
"""Utility class for encoding and serialization operations.""" """Utility class for encoding and serialization operations."""
# Shared ImageLoader for full EPD flow (async image loading)
_image_loader: Optional[ImageLoader] = None
_image_loader_lock = threading.Lock()
@classmethod
def _get_image_loader(cls) -> ImageLoader:
if cls._image_loader is None:
with cls._image_loader_lock:
if cls._image_loader is None:
cls._image_loader = ImageLoader()
return cls._image_loader
@staticmethod @staticmethod
def serialize_tensor_dict(tensor_dict: dict) -> dict: def serialize_tensor_dict(tensor_dict: dict) -> dict:
"""Serialize a dictionary of tensors to JSON-serializable format. """Serialize a dictionary of tensors to JSON-serializable format.
...@@ -269,7 +282,7 @@ class EncodeHelper: ...@@ -269,7 +282,7 @@ class EncodeHelper:
@staticmethod @staticmethod
async def _process_full_epd_flow( async def _process_full_epd_flow(
text_prompt: str, prompt_token_ids_from_request: list,
image_urls: list, image_urls: list,
tokenizer, tokenizer,
model_dir: str, model_dir: str,
...@@ -283,34 +296,33 @@ class EncodeHelper: ...@@ -283,34 +296,33 @@ class EncodeHelper:
containing multimodal embedding handles for the prefill worker. containing multimodal embedding handles for the prefill worker.
Args: Args:
text_prompt: Text portion of the prompt prompt_token_ids_from_request: token IDs from the request (Rust preprocessor)
image_urls: List of image URLs to process image_urls: List of image URLs to process
tokenizer: Tokenizer for encoding the processed prompt tokenizer: Tokenizer for decoding prompt_token_ids_from_request
model_dir: Path to model directory (required for AutoProcessor) model_dir: Path to model directory (unused; kept for API compatibility)
model_type: Model type string (required for placeholder retrieval) model_type: Model type string (unused; kept for API compatibility)
engine: TensorRTLLMEngine with MultimodalEncoder engine: TensorRTLLMEngine with MultimodalEncoder
Yields: Yields:
Response with ep_disaggregated_params, processed_prompt, and prompt_token_ids Response with ep_disaggregated_params, processed_prompt, and prompt_token_ids
""" """
# NOTE: `default_multimodal_input_loader` requires `model_dir` to load the # Load images with shared ImageLoader (async, same as multimodal_processor PD flow).
# HuggingFace AutoProcessor (for chat template application) and as a fallback image_items = [{"Url": u} for u in image_urls]
# for tokenizer loading. `model_type` is needed to retrieve the correct image_loader = EncodeHelper._get_image_loader()
# multimodal placeholders and apply model-specific preprocessing. pil_images = await image_loader.load_image_batch(image_items)
if not pil_images:
# NOTE: default_multimodal_input_loader downloads images and preprocesses them logging.error("ENCODE WORKER: no images loaded from image_urls")
# synchronously. Wrap in asyncio.to_thread to allow concurrent image loading yield {"ep_disaggregated_params": None}
# across multiple requests, improving throughput at high concurrency. return
inputs = await asyncio.to_thread(
lambda: default_multimodal_input_loader( processed_mm_data = {"image": pil_images}
tokenizer=tokenizer, inputs = [
model_dir=model_dir, {
model_type=model_type, "prompt_token_ids": prompt_token_ids_from_request,
modality="image", "multi_modal_data": processed_mm_data,
prompts=[text_prompt], "mm_processor_kwargs": {},
media=image_urls[0], }
) ]
)
# NOTE: MultimodalEncoder.generate() is synchronous. Run it off-thread to avoid # NOTE: MultimodalEncoder.generate() is synchronous. Run it off-thread to avoid
# blocking the encode worker's event loop under concurrency. # blocking the encode worker's event loop under concurrency.
...@@ -340,24 +352,15 @@ class EncodeHelper: ...@@ -340,24 +352,15 @@ class EncodeHelper:
encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params) encoded_params = DisaggregatedParamsCodec.encode(ep_disaggregated_params)
params_dict = asdict(encoded_params) params_dict = asdict(encoded_params)
# Extract processed prompt (includes <image> tokens) for prefill/decode consistency # Extract processed prompt (includes <image> tokens) for prefill/decode consistency.
processed_prompt = None # NOTE: processed_prompt will contain template/placeholder tokens
prompt_token_ids = None
if isinstance(inputs, list) and len(inputs) > 0:
first_input = inputs[0]
if isinstance(first_input, dict):
processed_prompt = first_input.get("prompt")
else:
processed_prompt = getattr(first_input, "prompt", None)
# Tokenize the processed prompt for prefill worker
if processed_prompt and tokenizer is not None:
# NOTE: processed_prompt already contains template/placeholder tokens
# (e.g. <image>, [INST], etc.). Adding special tokens here can change # (e.g. <image>, [INST], etc.). Adding special tokens here can change
# token alignment across EPD stages (prefill/decode), so we explicitly # token alignment across EPD stages (prefill/decode), so we explicitly
# avoid adding them. # avoid adding them.
prompt_token_ids = tokenizer.encode( processed_prompt = None
processed_prompt, add_special_tokens=False if tokenizer is not None:
processed_prompt = tokenizer.decode(
prompt_token_ids_from_request, skip_special_tokens=False
) )
logging.debug( logging.debug(
...@@ -368,7 +371,7 @@ class EncodeHelper: ...@@ -368,7 +371,7 @@ class EncodeHelper:
yield { yield {
"ep_disaggregated_params": params_dict, "ep_disaggregated_params": params_dict,
"processed_prompt": processed_prompt, "processed_prompt": processed_prompt,
"prompt_token_ids": prompt_token_ids, "prompt_token_ids": prompt_token_ids_from_request,
} }
@staticmethod @staticmethod
...@@ -407,7 +410,7 @@ class EncodeHelper: ...@@ -407,7 +410,7 @@ class EncodeHelper:
"messages", request.get("messages", []) "messages", request.get("messages", [])
) )
( (
text_prompt, _,
image_urls, image_urls,
embedding_paths, embedding_paths,
) = multimodal_processor.extract_prompt_and_media(messages) ) = multimodal_processor.extract_prompt_and_media(messages)
...@@ -423,7 +426,7 @@ class EncodeHelper: ...@@ -423,7 +426,7 @@ class EncodeHelper:
yield response yield response
# Flow 2: Full EPD flow (image URLs via MultimodalEncoder) # Flow 2: Full EPD flow (image URLs via MultimodalEncoder)
elif image_urls and text_prompt: elif image_urls and request.get("token_ids"):
if model_dir is None or model_type is None: if model_dir is None or model_type is None:
yield { yield {
"error": "model_dir and model_type are required for full EPD encode" "error": "model_dir and model_type are required for full EPD encode"
...@@ -432,11 +435,22 @@ class EncodeHelper: ...@@ -432,11 +435,22 @@ class EncodeHelper:
if engine is None: if engine is None:
yield {"error": "No engine configured on encode worker for full EPD"} yield {"error": "No engine configured on encode worker for full EPD"}
return return
# Use token_ids from request (Rust preprocessor already applied
# chat template and tokenized; token_ids then include image placeholder tokens
# if the model's tokenizer_config chat template emits them).
token_ids = request.get("token_ids")
async for response in EncodeHelper._process_full_epd_flow( async for response in EncodeHelper._process_full_epd_flow(
text_prompt, image_urls, tokenizer, model_dir, model_type, engine token_ids,
image_urls,
tokenizer,
model_dir,
model_type,
engine,
): ):
yield response yield response
# No valid multimodal content found # No valid multimodal content found
else: else:
yield {"error": "No embedding_paths or image_urls found in request"} yield {
"error": "No embedding_paths or image_urls found in request, or image_urls without text_prompt or token_ids"
}
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import time import time
from io import BytesIO from io import BytesIO
...@@ -23,9 +22,9 @@ from urllib.parse import urlparse ...@@ -23,9 +22,9 @@ from urllib.parse import urlparse
from urllib.request import urlopen from urllib.request import urlopen
import torch import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
...@@ -69,6 +68,8 @@ class MultimodalRequestProcessor: ...@@ -69,6 +68,8 @@ class MultimodalRequestProcessor:
else: else:
self.tokenizer = tokenizer_factory(model_dir) self.tokenizer = tokenizer_factory(model_dir)
self.image_loader = ImageLoader()
def is_url(self, path: str) -> bool: def is_url(self, path: str) -> bool:
"""Check if a path is a URL.""" """Check if a path is a URL."""
parsed = urlparse(path) parsed = urlparse(path)
...@@ -163,77 +164,151 @@ class MultimodalRequestProcessor: ...@@ -163,77 +164,151 @@ class MultimodalRequestProcessor:
else: else:
image_urls.append(url) image_urls.append(url)
return " ".join(text_parts), image_urls, embedding_paths return "".join(text_parts), image_urls, embedding_paths
async def process_openai_request( async def process_openai_request(
self, request: Dict, embeddings: Any, ep_disaggregated_params: Any self, request: Dict, embeddings: Any, ep_disaggregated_params: Any
) -> Optional[Any]: ) -> Optional[Any]:
"""Process OpenAI request and return with multimodal data.""" """
# Extract messages - check extra_args first (from Rust preprocessor for multimodal) Process OpenAI request and return multimodal data in TokensPrompt format.
# Fall back to direct messages field for backward compatibility
Supports three flows:
1. EPD Case 1: Encoder fully processed (has _epd_processed_prompt)
2. EPD Case 2: NIXL embeddings (embeddings parameter is not None)
3. PD Flow: Rust pre-tokenized with direct media loading
Returns dict compatible with TRT-LLM's generate_async:
{
"prompt_token_ids": List[int],
"multi_modal_data": Dict[str, List[torch.Tensor]]
}
or for EPD Case 1:
{
"prompt": str,
"prompt_token_ids": List[int]
}
"""
self.previous_decoded_text = "" self.previous_decoded_text = ""
messages = request.get("extra_args", {}).get(
"messages", request.get("messages", [])
)
text_prompt, image_urls, embedding_paths = self.extract_prompt_and_media(
messages
)
if not image_urls and not embedding_paths and not ep_disaggregated_params:
logging.warning("No multimodal content, returning None")
return None
# EPD Flow Case 1: Encoder has fully processed the prompt
# The encode worker has done everything: vision encoding, prompt processing, tokenization
# Return the encoder's processed prompt and tokens directly
processed_prompt_from_encoder = request.get("_epd_processed_prompt") processed_prompt_from_encoder = request.get("_epd_processed_prompt")
# Only use EPD flow if we actually have encoder data
# For PD flow (no encoder), fall through to embedding_paths handling
if processed_prompt_from_encoder is not None: if processed_prompt_from_encoder is not None:
text_prompt = processed_prompt_from_encoder logging.info("MM: Using fully processed prompt from encoder")
result = {"prompt": text_prompt} result = {"prompt": processed_prompt_from_encoder}
prompt_token_ids = request.get("_epd_prompt_token_ids") prompt_token_ids = request.get("_epd_prompt_token_ids")
if prompt_token_ids: if prompt_token_ids:
result["prompt_token_ids"] = prompt_token_ids result["prompt_token_ids"] = prompt_token_ids
else: else:
logging.warning("MM PROCESSOR: No prompt_token_ids from encoder") logging.warning("MM: No prompt_token_ids from encoder")
return result return result
loader_kwargs = {}
# Get token_ids from request (already tokenized by Rust frontend)
token_ids = request.get("token_ids")
if not token_ids:
logging.warning("No token_ids in request")
return None
# Initialize result in TokensPrompt format
# mm_processor_kwargs must be a dict (not None) for TRT-LLM's processor
processed_inputs = {"prompt_token_ids": token_ids, "mm_processor_kwargs": {}}
# EPD Flow Case 2: Embeddings received via NIXL from encode worker
# The encode worker computed vision embeddings and transferred them via RDMA/NIXL
# We need to pass these embeddings directly to TRT-LLM's generate_async
if embeddings is not None: if embeddings is not None:
# EPD flow - embeddings received from encode worker via NIXL
loader_kwargs["mm_embeddings"] = [embeddings]
logging.info( logging.info(
f"Using NIXL embeddings: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}" f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
) )
elif image_urls:
# Image-only flow # Structure embeddings in the format TRT-LLM's generate_async expects
loader_kwargs["media"] = [image_urls] processed_inputs["multi_modal_embeddings"] = embeddings
elif embedding_paths:
# PD flow with no NIXL and no encoder return processed_inputs
loader_kwargs["mm_embeddings"] = [
self.load_tensor_from_path_or_url(path) for path in embedding_paths # PD Flow: Pre-tokenized by Rust frontend with direct media loading
] # TODO: Add frontend decoding support
logging.info(f"Using embedding paths: {embedding_paths}")
# Handle multimodal data if present
# NOTE: default_multimodal_input_loader downloads images and preprocesses them multi_modal_data = request.get("multi_modal_data")
# synchronously. Wrap in asyncio.to_thread to allow concurrent image loading if multi_modal_data and isinstance(multi_modal_data, dict):
# across multiple requests, improving throughput at high concurrency. processed_mm_data = {}
processed_inputs = await asyncio.to_thread(
lambda: default_multimodal_input_loader( # Process images and embedding paths from image_url field
tokenizer=self.tokenizer, image_items = multi_modal_data.get("image_url", [])
model_dir=self.model_dir, if image_items and isinstance(image_items, list):
model_type=self.model_type, # Separate embedding paths from regular image URLs
modality=self.modality, # Items come from Rust in format: {"Url": "..."} or {"Decoded": ...}
prompts=[text_prompt], embedding_paths = []
image_data_format="pt", image_urls = []
device="cuda",
**loader_kwargs, for item in image_items:
# Extract URL from item (Rust enum serialization uses "Url" with capital U)
if isinstance(item, dict) and "Url" in item:
url = item["Url"]
elif isinstance(item, dict) and "Decoded" in item:
# Already decoded data (NIXL) - always treat as image
image_urls.append(item)
continue
elif isinstance(item, str):
# Fallback for string URLs (backward compatibility)
url = item
else:
logging.warning(
f"Unexpected item format in image_items: {item}"
) )
continue
# Check if this is an embedding file based on extension
if url.endswith((".pt", ".pth", ".bin")):
embedding_paths.append(url)
else:
# Keep original item format for load_image_batch
image_urls.append(
item if isinstance(item, dict) else {"Url": item}
) )
# Return the first processed input if available # Load regular images as PIL Images for TRT-LLM's input processor
if processed_inputs: # TRT-LLM will auto-detect this and compute mrope_config
return processed_inputs[0] if image_urls:
try:
pil_images = await self.image_loader.load_image_batch(
image_urls
)
if pil_images:
processed_mm_data["image"] = pil_images
logging.info(
f"Loaded {len(pil_images)} image(s) as PIL Images"
)
except Exception as e:
logging.error(f"Failed to load images: {e}")
return None
# Load embedding files (.pt, .pth, .bin) for PD flow
# These are pre-computed vision encoder outputs
if embedding_paths:
try:
loaded_embeddings = [
self.load_tensor_from_path_or_url(path)
for path in embedding_paths
]
if loaded_embeddings:
processed_mm_data["embedding"] = loaded_embeddings
logging.info(
f"Loaded {len(loaded_embeddings)} embedding file(s) from paths: {embedding_paths}"
)
except Exception as e:
logging.error(f"Failed to load embeddings: {e}")
return None return None
# TODO: Add support for video_url, audio_url
if processed_mm_data:
processed_inputs["multi_modal_data"] = processed_mm_data
return processed_inputs
def create_response_chunk( def create_response_chunk(
self, self,
output: Any, output: Any,
......
...@@ -520,7 +520,16 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -520,7 +520,16 @@ class HandlerBase(BaseGenerativeHandler):
if processed_input: if processed_input:
return processed_input return processed_input
# Fallback: text-only flow # If multimodal processing returned None but request has multimodal data,
# this is an error (not a text-only request). Raise instead of falling back.
if request.get("multi_modal_data"):
raise RuntimeError(
"Failed to process multimodal request. Check server logs for details. "
"Common issues: missing allowed_local_media_path configuration, "
"file not found, or file outside allowed directory."
)
# Fallback: text-only flow (no multimodal processor or no multimodal data)
return request.get("token_ids") return request.get("token_ids")
def _normalize_request_format(self, request: dict) -> None: def _normalize_request_format(self, request: dict) -> None:
......
...@@ -23,9 +23,9 @@ from vllm.sampling_params import SamplingParams, StructuredOutputsParams ...@@ -23,9 +23,9 @@ from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.otel_tracing import build_trace_headers from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.llm import ( from dynamo.llm import (
KvEventPublisher, KvEventPublisher,
...@@ -39,7 +39,6 @@ from dynamo.runtime.logging import configure_dynamo_logging ...@@ -39,7 +39,6 @@ from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.image_loader import ImageLoader
# Multimodal data dictionary keys # Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url" IMAGE_URL_KEY: Final = "image_url"
...@@ -904,68 +903,6 @@ class BaseWorkerHandler(ABC): ...@@ -904,68 +903,6 @@ class BaseWorkerHandler(ABC):
return prompt, sequence_length, embeddings_tensor return prompt, sequence_length, embeddings_tensor
async def _load_image_batch(
self, image_mm_items: list[Dict[str, Any]]
) -> list[Any]:
"""
Load a batch of images from multimodal data items.
Supports two paths:
1. Url variant: Download and decode image from URL (default)
2. Decoded variant: Read pre-decoded image via NIXL RDMA (requires --frontend-decoding)
Args:
image_mm_items: List of multimodal data items for images
Returns:
List of loaded image data
Raises:
Exception: If any image fails to load
"""
image_futures = []
for item in image_mm_items:
if isinstance(item, dict) and URL_VARIANT_KEY in item:
# URL path: download and decode in Python backend
url = item[URL_VARIANT_KEY]
image_futures.append(self.image_loader.load_image(url))
logger.debug(f"Preparing to load image from URL: {url[:80]}...")
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
if self.enable_frontend_decoding:
async with self._nixl_connector_lock:
if self._nixl_connector is None:
self._nixl_connector = nixl_connect.Connector()
await self._nixl_connector.initialize()
metadata = item[DECODED_VARIANT_KEY]
image_futures.append(
read_decoded_media_via_nixl(self._nixl_connector, metadata)
)
else:
logger.error(
"Received Decoded multimodal data but --frontend-decoding not enabled. "
"Use --frontend-decoding flag to enable NIXL RDMA image transfer."
)
raise ValueError("Could not load decoded media from frontend")
# Process images in parallel
results = await asyncio.gather(*image_futures, return_exceptions=True)
loaded_images = []
collective_exceptions = ""
for media_item, result in zip(image_mm_items, results):
if isinstance(result, Exception):
source = media_item.get(URL_VARIANT_KEY, "decoded")
logger.error(f"Failed to load image from {source[:80]}...: {result}")
collective_exceptions += (
f"Failed to load image from {source[:80]}...: {result}\n"
)
continue
loaded_images.append(result)
if collective_exceptions:
raise Exception(collective_exceptions)
return loaded_images
async def _extract_multimodal_data( async def _extract_multimodal_data(
self, request: Dict[str, Any] self, request: Dict[str, Any]
) -> Dict[str, Any] | None: ) -> Dict[str, Any] | None:
...@@ -985,8 +922,19 @@ class BaseWorkerHandler(ABC): ...@@ -985,8 +922,19 @@ class BaseWorkerHandler(ABC):
mm_map = request["multi_modal_data"] mm_map = request["multi_modal_data"]
vllm_mm_data = {} vllm_mm_data = {}
# Lazy-init NIXL connector only when frontend decoding is enabled
if self.enable_frontend_decoding:
async with self._nixl_connector_lock:
if self._nixl_connector is None:
self._nixl_connector = nixl_connect.Connector()
await self._nixl_connector.initialize()
# Process image_url entries # Process image_url entries
images = await self._load_image_batch(mm_map.get(IMAGE_URL_KEY, [])) images = await self.image_loader.load_image_batch(
mm_map.get(IMAGE_URL_KEY, []),
enable_frontend_decoding=self.enable_frontend_decoding,
nixl_connector=self._nixl_connector,
)
if images: if images:
# vLLM expects single image or list # vLLM expects single image or list
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dynamo.common.multimodal.http_client import get_http_client
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.chat_processor import ( from dynamo.vllm.multimodal_utils.chat_processor import (
ChatProcessor, ChatProcessor,
...@@ -12,8 +14,6 @@ from dynamo.vllm.multimodal_utils.encode_utils import ( ...@@ -12,8 +14,6 @@ from dynamo.vllm.multimodal_utils.encode_utils import (
get_embedding_hash, get_embedding_hash,
get_encoder_components, get_encoder_components,
) )
from dynamo.vllm.multimodal_utils.http_client import get_http_client
from dynamo.vllm.multimodal_utils.image_loader import ImageLoader
from dynamo.vllm.multimodal_utils.model import ( from dynamo.vllm.multimodal_utils.model import (
SupportedModels, SupportedModels,
construct_mm_data, construct_mm_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment