Unverified Commit bb2ca1a9 authored by Michal Guzek's avatar Michal Guzek Committed by GitHub
Browse files

fix: TRT-LLM multimodal preprocessor - revert to the old...


fix: TRT-LLM multimodal preprocessor - revert to the old default_multimodal_input_loader for the embeddings case (#6840)
Signed-off-by: default avatarMichal Guzek <mguzek@nvidia.com>
parent d993f9d3
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import time import time
from io import BytesIO from io import BytesIO
...@@ -22,6 +23,7 @@ from urllib.parse import urlparse ...@@ -22,6 +23,7 @@ from urllib.parse import urlparse
from urllib.request import urlopen from urllib.request import urlopen
import torch import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
...@@ -187,6 +189,11 @@ class MultimodalRequestProcessor: ...@@ -187,6 +189,11 @@ class MultimodalRequestProcessor:
"prompt": str, "prompt": str,
"prompt_token_ids": List[int] "prompt_token_ids": List[int]
} }
-----------------------------------------------------------------------------
TODO: Revert default_multimodal_input_loader calls having fixed TRT-LLM's
token IDs & MM data path in generate_async() for the embeddings case.
-----------------------------------------------------------------------------
""" """
self.previous_decoded_text = "" self.previous_decoded_text = ""
...@@ -214,18 +221,50 @@ class MultimodalRequestProcessor: ...@@ -214,18 +221,50 @@ class MultimodalRequestProcessor:
# mm_processor_kwargs must be a dict (not None) for TRT-LLM's processor # mm_processor_kwargs must be a dict (not None) for TRT-LLM's processor
processed_inputs = {"prompt_token_ids": token_ids, "mm_processor_kwargs": {}} processed_inputs = {"prompt_token_ids": token_ids, "mm_processor_kwargs": {}}
# EPD Flow Case 2: Embeddings received via NIXL from encode worker # The aforementioned fallback to default_multimodal_input_loader:
# The encode worker computed vision embeddings and transferred them via RDMA/NIXL messages = request.get("extra_args", {}).get(
# We need to pass these embeddings directly to TRT-LLM's generate_async "messages", request.get("messages", [])
if embeddings is not None: )
logging.info( text_prompt, _, embedding_paths = self.extract_prompt_and_media(messages)
f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}" loader_kwargs = {}
# Two cases, both for the default_multimodal_input_loader fallback:
# 1) EPD Flow Case 2: Embeddings received via NIXL from encode worker
# The encode worker computed vision embeddings and transferred them via RDMA/NIXL
# We need to pass these embeddings directly to TRT-LLM's generate_async
# 2) PD flow with no NIXL and no encoder
if embeddings is not None or embedding_paths:
if embeddings is not None:
logging.info(
f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
)
loader_kwargs["mm_embeddings"] = [embeddings]
elif embedding_paths:
# PD flow with no NIXL and no encoder
loader_kwargs["mm_embeddings"] = [
self.load_tensor_from_path_or_url(path) for path in embedding_paths
]
logging.info(f"Using embedding paths: {embedding_paths}")
# NOTE: default_multimodal_input_loader downloads images and preprocesses them
# synchronously. Wrap in asyncio.to_thread to allow concurrent image loading
# across multiple requests, improving throughput at high concurrency.
fallback_processed_inputs = await asyncio.to_thread(
lambda: default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_dir,
model_type=self.model_type,
modality=self.modality,
prompts=[text_prompt],
image_data_format="pt",
device="cuda",
**loader_kwargs,
)
) )
# Return the first processed input if available
# Structure embeddings in the format TRT-LLM's generate_async expects if fallback_processed_inputs:
processed_inputs["multi_modal_embeddings"] = embeddings return fallback_processed_inputs[0]
return None
return processed_inputs
# PD Flow: Pre-tokenized by Rust frontend with direct media loading # PD Flow: Pre-tokenized by Rust frontend with direct media loading
# TODO: Add frontend decoding support # TODO: Add frontend decoding support
...@@ -240,7 +279,6 @@ class MultimodalRequestProcessor: ...@@ -240,7 +279,6 @@ class MultimodalRequestProcessor:
if image_items and isinstance(image_items, list): if image_items and isinstance(image_items, list):
# Separate embedding paths from regular image URLs # Separate embedding paths from regular image URLs
# Items come from Rust in format: {"Url": "..."} or {"Decoded": ...} # Items come from Rust in format: {"Url": "..."} or {"Decoded": ...}
embedding_paths = []
image_urls = [] image_urls = []
for item in image_items: for item in image_items:
...@@ -261,9 +299,7 @@ class MultimodalRequestProcessor: ...@@ -261,9 +299,7 @@ class MultimodalRequestProcessor:
continue continue
# Check if this is an embedding file based on extension # Check if this is an embedding file based on extension
if url.endswith((".pt", ".pth", ".bin")): if not url.endswith((".pt", ".pth", ".bin")):
embedding_paths.append(url)
else:
# Keep original item format for load_image_batch # Keep original item format for load_image_batch
image_urls.append( image_urls.append(
item if isinstance(item, dict) else {"Url": item} item if isinstance(item, dict) else {"Url": item}
...@@ -285,23 +321,6 @@ class MultimodalRequestProcessor: ...@@ -285,23 +321,6 @@ class MultimodalRequestProcessor:
logging.error(f"Failed to load images: {e}") logging.error(f"Failed to load images: {e}")
return None return None
# Load embedding files (.pt, .pth, .bin) for PD flow
# These are pre-computed vision encoder outputs
if embedding_paths:
try:
loaded_embeddings = [
self.load_tensor_from_path_or_url(path)
for path in embedding_paths
]
if loaded_embeddings:
processed_mm_data["embedding"] = loaded_embeddings
logging.info(
f"Loaded {len(loaded_embeddings)} embedding file(s) from paths: {embedding_paths}"
)
except Exception as e:
logging.error(f"Failed to load embeddings: {e}")
return None
# TODO: Add support for video_url, audio_url # TODO: Add support for video_url, audio_url
if processed_mm_data: if processed_mm_data:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment