Unverified commit e6ddf0ea, authored by Michal Guzek, committed by GitHub
Browse files

fix: TRT-LLM multimodal preprocessor - remove default_multimodal_input_loader...


fix: TRT-LLM multimodal preprocessor - remove default_multimodal_input_loader from the embedding paths (#6924)
Signed-off-by: Michal Guzek <mguzek@nvidia.com>
parent a620a9cf
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import time import time
from io import BytesIO from io import BytesIO
...@@ -23,7 +22,6 @@ from urllib.parse import urlparse ...@@ -23,7 +22,6 @@ from urllib.parse import urlparse
from urllib.request import urlopen from urllib.request import urlopen
import torch import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
...@@ -195,10 +193,6 @@ class MultimodalRequestProcessor: ...@@ -195,10 +193,6 @@ class MultimodalRequestProcessor:
"prompt_token_ids": List[int] "prompt_token_ids": List[int]
} }
-----------------------------------------------------------------------------
TODO: Revert default_multimodal_input_loader calls having fixed TRT-LLM's
token IDs & MM data path in generate_async() for the embeddings case.
-----------------------------------------------------------------------------
""" """
self.previous_decoded_text = "" self.previous_decoded_text = ""
...@@ -216,60 +210,35 @@ class MultimodalRequestProcessor: ...@@ -216,60 +210,35 @@ class MultimodalRequestProcessor:
logging.warning("MM: No prompt_token_ids from encoder") logging.warning("MM: No prompt_token_ids from encoder")
return result return result
# Get token_ids from request (already tokenized by Rust frontend)
token_ids = request.get("token_ids")
if not token_ids:
logging.warning("No token_ids in request")
return None
# Initialize result in TokensPrompt format # Initialize result in TokensPrompt format
# mm_processor_kwargs must be a dict (not None) for TRT-LLM's processor # mm_processor_kwargs must be a dict (not None) for TRT-LLM's processor
processed_inputs = {"prompt_token_ids": token_ids, "mm_processor_kwargs": {}} processed_inputs: Dict[str, Any] = {"mm_processor_kwargs": {}}
# The aforementioned fallback to default_multimodal_input_loader: # TODO(TRTLLM-11294): Remove the fallback to text_prompt for EPD-NIXL and embeddings cases.
messages = request.get("extra_args", {}).get( # This is a temporary workaround to bypass TRT-LLM's bug where token IDs & embeddings
"messages", request.get("messages", []) # are not processed correctly.
) extra_args = request.get("extra_args") or {}
text_prompt, _, embedding_paths = self.extract_prompt_and_media(messages) formatted_prompt_from_frontend = extra_args.get("formatted_prompt")
loader_kwargs = {}
# Two cases, both for the default_multimodal_input_loader fallback: # EPD Flow Case 2: Embeddings received via NIXL from encode worker
# 1) EPD Flow Case 2: Embeddings received via NIXL from encode worker
# The encode worker computed vision embeddings and transferred them via RDMA/NIXL # The encode worker computed vision embeddings and transferred them via RDMA/NIXL
# We need to pass these embeddings directly to TRT-LLM's generate_async # We need to pass these embeddings directly to TRT-LLM's generate_async
# 2) PD flow with no NIXL and no encoder
if embeddings is not None or embedding_paths:
if embeddings is not None: if embeddings is not None:
logging.info( logging.info(
f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}" f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
) )
loader_kwargs["mm_embeddings"] = [embeddings]
elif embedding_paths: # Same structure as PD flow (TRT-LLM expects dict with "image" key)
# PD flow with no NIXL and no encoder image_embeddings = (
loader_kwargs["mm_embeddings"] = [ embeddings if isinstance(embeddings, list) else [embeddings]
self.load_tensor_from_path_or_url(path) for path in embedding_paths
]
logging.info(f"Using embedding paths: {embedding_paths}")
# NOTE: default_multimodal_input_loader downloads images and preprocesses them
# synchronously. Wrap in asyncio.to_thread to allow concurrent image loading
# across multiple requests, improving throughput at high concurrency.
fallback_processed_inputs = await asyncio.to_thread(
lambda: default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_dir,
model_type=self.model_type,
modality=self.modality,
prompts=[text_prompt],
image_data_format="pt",
device="cuda",
**loader_kwargs,
)
) )
# Return the first processed input if available processed_inputs["multi_modal_embeddings"] = {"image": image_embeddings}
if fallback_processed_inputs: if formatted_prompt_from_frontend:
return fallback_processed_inputs[0] processed_inputs["prompt"] = formatted_prompt_from_frontend
else:
logging.warning("No formatted prompt from frontend")
return None return None
return processed_inputs
# PD Flow: Pre-tokenized by Rust frontend with direct media loading # PD Flow: Pre-tokenized by Rust frontend with direct media loading
# TODO: Add frontend decoding support # TODO: Add frontend decoding support
...@@ -278,12 +247,14 @@ class MultimodalRequestProcessor: ...@@ -278,12 +247,14 @@ class MultimodalRequestProcessor:
multi_modal_data = request.get("multi_modal_data") multi_modal_data = request.get("multi_modal_data")
if multi_modal_data and isinstance(multi_modal_data, dict): if multi_modal_data and isinstance(multi_modal_data, dict):
processed_mm_data = {} processed_mm_data = {}
loaded_embeddings = []
# Process images and embedding paths from image_url field # Process images and embedding paths from image_url field
image_items = multi_modal_data.get("image_url", []) image_items = multi_modal_data.get("image_url", [])
if image_items and isinstance(image_items, list): if image_items and isinstance(image_items, list):
# Separate embedding paths from regular image URLs # Separate embedding paths from regular image URLs
# Items come from Rust in format: {"Url": "..."} or {"Decoded": ...} # Items come from Rust in format: {"Url": "..."} or {"Decoded": ...}
embedding_paths = []
image_urls = [] image_urls = []
for item in image_items: for item in image_items:
...@@ -304,7 +275,9 @@ class MultimodalRequestProcessor: ...@@ -304,7 +275,9 @@ class MultimodalRequestProcessor:
continue continue
# Check if this is an embedding file based on extension # Check if this is an embedding file based on extension
if not url.endswith((".pt", ".pth", ".bin")): if url.endswith((".pt", ".pth", ".bin")):
embedding_paths.append(url)
else:
# Keep original item format for load_image_batch # Keep original item format for load_image_batch
image_urls.append( image_urls.append(
item if isinstance(item, dict) else {"Url": item} item if isinstance(item, dict) else {"Url": item}
...@@ -326,11 +299,48 @@ class MultimodalRequestProcessor: ...@@ -326,11 +299,48 @@ class MultimodalRequestProcessor:
logging.error(f"Failed to load images: {e}") logging.error(f"Failed to load images: {e}")
return None return None
# Load embedding files (.pt, .pth, .bin) for PD flow
# These are pre-computed vision encoder outputs
if embedding_paths:
try:
loaded_embeddings = [
self.load_tensor_from_path_or_url(path)
for path in embedding_paths
]
if loaded_embeddings:
logging.info(
f"Loaded {len(loaded_embeddings)} embedding file(s) from paths: {embedding_paths}"
)
except Exception as e:
logging.error(f"Failed to load embeddings: {e}")
return None
# TODO: Add support for video_url, audio_url # TODO: Add support for video_url, audio_url
if loaded_embeddings:
# For TRT-LLM MM embeddings, the currently
# supported modality is "image".
if formatted_prompt_from_frontend:
processed_inputs["prompt"] = formatted_prompt_from_frontend
else:
logging.warning("No formatted prompt from frontend")
return None
processed_inputs["multi_modal_embeddings"] = {
"image": loaded_embeddings
}
return processed_inputs
if processed_mm_data: if processed_mm_data:
processed_inputs["multi_modal_data"] = processed_mm_data processed_inputs["multi_modal_data"] = processed_mm_data
# Get token_ids from request (already tokenized by Rust frontend)
token_ids = request.get("token_ids")
if not token_ids:
logging.warning("No token_ids in request")
return None
processed_inputs["prompt_token_ids"] = token_ids
return processed_inputs return processed_inputs
def create_response_chunk( def create_response_chunk(
......
...@@ -231,9 +231,9 @@ impl OpenAIPreprocessor { ...@@ -231,9 +231,9 @@ impl OpenAIPreprocessor {
.apply_template(request) .apply_template(request)
.with_context(|| "Failed to apply prompt template")?; .with_context(|| "Failed to apply prompt template")?;
let annotations = self let annotations = self
.gather_tokens(request, &mut builder, formatted_prompt, tracker) .gather_tokens(request, &mut builder, formatted_prompt.clone(), tracker)
.with_context(|| "Failed to gather tokens")?; .with_context(|| "Failed to gather tokens")?;
self.gather_multi_modal_data(request, &mut builder) self.gather_multi_modal_data(request, &mut builder, formatted_prompt)
.await .await
.with_context(|| "Failed to gather multimodal data")?; .with_context(|| "Failed to gather multimodal data")?;
...@@ -351,6 +351,7 @@ impl OpenAIPreprocessor { ...@@ -351,6 +351,7 @@ impl OpenAIPreprocessor {
&self, &self,
request: &R, request: &R,
builder: &mut PreprocessedRequestBuilder, builder: &mut PreprocessedRequestBuilder,
formatted_prompt: Option<String>,
) -> Result<()> { ) -> Result<()> {
let mut media_map: MultimodalDataMap = HashMap::new(); let mut media_map: MultimodalDataMap = HashMap::new();
let mut fetch_tasks: Vec<(String, ChatCompletionRequestUserMessageContentPart)> = let mut fetch_tasks: Vec<(String, ChatCompletionRequestUserMessageContentPart)> =
...@@ -417,12 +418,16 @@ impl OpenAIPreprocessor { ...@@ -417,12 +418,16 @@ impl OpenAIPreprocessor {
if !media_map.is_empty() { if !media_map.is_empty() {
builder.multi_modal_data(Some(media_map)); builder.multi_modal_data(Some(media_map));
// Preserve original messages in extra_args for multimodal workers that need them // Preserve original messages and formatted prompt in extra_args for multimodal
// (e.g., TRT-LLM multimodal processor needs raw messages for proper tokenization) // workers (e.g., TRT-LLM needs messages and the template-rendered prompt with
// <image> placeholders for embedding-path / NIXL flows).
let messages_json = serde_json::to_value(request.messages())?; let messages_json = serde_json::to_value(request.messages())?;
let extra_args = serde_json::json!({ let mut extra_args = serde_json::json!({
"messages": messages_json "messages": messages_json
}); });
if let Some(ref prompt) = formatted_prompt {
extra_args["formatted_prompt"] = serde_json::Value::String(prompt.clone());
}
builder.extra_args(Some(extra_args)); builder.extra_args(Some(extra_args));
} }
...@@ -1335,7 +1340,8 @@ impl ...@@ -1335,7 +1340,8 @@ impl
}; };
// Gather multimodal data (works with both embeddings and text prompts) // Gather multimodal data (works with both embeddings and text prompts)
self.gather_multi_modal_data(&request, &mut builder).await?; self.gather_multi_modal_data(&request, &mut builder, None)
.await?;
let mut common_request = builder.build()?; let mut common_request = builder.build()?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment