"lib/runtime/vscode:/vscode.git/clone" did not exist on "6fe2152bfca7bb1fa681eae957b252c6cd0b7a53"
Unverified commit e6ddf0ea, authored by Michal Guzek, committed by GitHub
Browse files

fix: TRT-LLM multimodal preprocessor - remove default_multimodal_input_loader...


fix: TRT-LLM multimodal preprocessor - remove default_multimodal_input_loader from the embedding paths (#6924)
Signed-off-by: Michal Guzek <mguzek@nvidia.com>
parent a620a9cf
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import time
from io import BytesIO
......@@ -23,7 +22,6 @@ from urllib.parse import urlparse
from urllib.request import urlopen
import torch
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.common.multimodal.image_loader import ImageLoader
......@@ -195,10 +193,6 @@ class MultimodalRequestProcessor:
"prompt_token_ids": List[int]
}
-----------------------------------------------------------------------------
TODO: Revert default_multimodal_input_loader calls having fixed TRT-LLM's
token IDs & MM data path in generate_async() for the embeddings case.
-----------------------------------------------------------------------------
"""
self.previous_decoded_text = ""
......@@ -216,60 +210,35 @@ class MultimodalRequestProcessor:
logging.warning("MM: No prompt_token_ids from encoder")
return result
# Get token_ids from request (already tokenized by Rust frontend)
token_ids = request.get("token_ids")
if not token_ids:
logging.warning("No token_ids in request")
return None
# Initialize result in TokensPrompt format
# mm_processor_kwargs must be a dict (not None) for TRT-LLM's processor
processed_inputs = {"prompt_token_ids": token_ids, "mm_processor_kwargs": {}}
processed_inputs: Dict[str, Any] = {"mm_processor_kwargs": {}}
# TODO(TRTLLM-11294): Remove the fallback to text_prompt for EPD-NIXL and embeddings cases.
# This is a temporary workaround to bypass TRT-LLM's bug where token IDs & embeddings
# are not processed correctly.
extra_args = request.get("extra_args") or {}
formatted_prompt_from_frontend = extra_args.get("formatted_prompt")
# EPD Flow Case 2: Embeddings received via NIXL from encode worker
# The encode worker computed vision embeddings and transferred them via RDMA/NIXL
# We need to pass these embeddings directly to TRT-LLM's generate_async
if embeddings is not None:
logging.info(
f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
)
# The aforementioned fallback to default_multimodal_input_loader:
messages = request.get("extra_args", {}).get(
"messages", request.get("messages", [])
)
text_prompt, _, embedding_paths = self.extract_prompt_and_media(messages)
loader_kwargs = {}
# Two cases, both for the default_multimodal_input_loader fallback:
# 1) EPD Flow Case 2: Embeddings received via NIXL from encode worker
# The encode worker computed vision embeddings and transferred them via RDMA/NIXL
# We need to pass these embeddings directly to TRT-LLM's generate_async
# 2) PD flow with no NIXL and no encoder
if embeddings is not None or embedding_paths:
if embeddings is not None:
logging.info(
f"Using NIXL embeddings from encoder: shape={embeddings.shape if hasattr(embeddings, 'shape') else 'N/A'}"
)
loader_kwargs["mm_embeddings"] = [embeddings]
elif embedding_paths:
# PD flow with no NIXL and no encoder
loader_kwargs["mm_embeddings"] = [
self.load_tensor_from_path_or_url(path) for path in embedding_paths
]
logging.info(f"Using embedding paths: {embedding_paths}")
# NOTE: default_multimodal_input_loader downloads images and preprocesses them
# synchronously. Wrap in asyncio.to_thread to allow concurrent image loading
# across multiple requests, improving throughput at high concurrency.
fallback_processed_inputs = await asyncio.to_thread(
lambda: default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_dir,
model_type=self.model_type,
modality=self.modality,
prompts=[text_prompt],
image_data_format="pt",
device="cuda",
**loader_kwargs,
)
# Same structure as PD flow (TRT-LLM expects dict with "image" key)
image_embeddings = (
embeddings if isinstance(embeddings, list) else [embeddings]
)
# Return the first processed input if available
if fallback_processed_inputs:
return fallback_processed_inputs[0]
return None
processed_inputs["multi_modal_embeddings"] = {"image": image_embeddings}
if formatted_prompt_from_frontend:
processed_inputs["prompt"] = formatted_prompt_from_frontend
else:
logging.warning("No formatted prompt from frontend")
return None
return processed_inputs
# PD Flow: Pre-tokenized by Rust frontend with direct media loading
# TODO: Add frontend decoding support
......@@ -278,12 +247,14 @@ class MultimodalRequestProcessor:
multi_modal_data = request.get("multi_modal_data")
if multi_modal_data and isinstance(multi_modal_data, dict):
processed_mm_data = {}
loaded_embeddings = []
# Process images and embedding paths from image_url field
image_items = multi_modal_data.get("image_url", [])
if image_items and isinstance(image_items, list):
# Separate embedding paths from regular image URLs
# Items come from Rust in format: {"Url": "..."} or {"Decoded": ...}
embedding_paths = []
image_urls = []
for item in image_items:
......@@ -304,7 +275,9 @@ class MultimodalRequestProcessor:
continue
# Check if this is an embedding file based on extension
if not url.endswith((".pt", ".pth", ".bin")):
if url.endswith((".pt", ".pth", ".bin")):
embedding_paths.append(url)
else:
# Keep original item format for load_image_batch
image_urls.append(
item if isinstance(item, dict) else {"Url": item}
......@@ -326,11 +299,48 @@ class MultimodalRequestProcessor:
logging.error(f"Failed to load images: {e}")
return None
# Load embedding files (.pt, .pth, .bin) for PD flow
# These are pre-computed vision encoder outputs
if embedding_paths:
try:
loaded_embeddings = [
self.load_tensor_from_path_or_url(path)
for path in embedding_paths
]
if loaded_embeddings:
logging.info(
f"Loaded {len(loaded_embeddings)} embedding file(s) from paths: {embedding_paths}"
)
except Exception as e:
logging.error(f"Failed to load embeddings: {e}")
return None
# TODO: Add support for video_url, audio_url
if loaded_embeddings:
# For TRT-LLM MM embeddings, the currently
# supported modality is "image".
if formatted_prompt_from_frontend:
processed_inputs["prompt"] = formatted_prompt_from_frontend
else:
logging.warning("No formatted prompt from frontend")
return None
processed_inputs["multi_modal_embeddings"] = {
"image": loaded_embeddings
}
return processed_inputs
if processed_mm_data:
processed_inputs["multi_modal_data"] = processed_mm_data
# Get token_ids from request (already tokenized by Rust frontend)
token_ids = request.get("token_ids")
if not token_ids:
logging.warning("No token_ids in request")
return None
processed_inputs["prompt_token_ids"] = token_ids
return processed_inputs
def create_response_chunk(
......
......@@ -231,9 +231,9 @@ impl OpenAIPreprocessor {
.apply_template(request)
.with_context(|| "Failed to apply prompt template")?;
let annotations = self
.gather_tokens(request, &mut builder, formatted_prompt, tracker)
.gather_tokens(request, &mut builder, formatted_prompt.clone(), tracker)
.with_context(|| "Failed to gather tokens")?;
self.gather_multi_modal_data(request, &mut builder)
self.gather_multi_modal_data(request, &mut builder, formatted_prompt)
.await
.with_context(|| "Failed to gather multimodal data")?;
......@@ -351,6 +351,7 @@ impl OpenAIPreprocessor {
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
formatted_prompt: Option<String>,
) -> Result<()> {
let mut media_map: MultimodalDataMap = HashMap::new();
let mut fetch_tasks: Vec<(String, ChatCompletionRequestUserMessageContentPart)> =
......@@ -417,12 +418,16 @@ impl OpenAIPreprocessor {
if !media_map.is_empty() {
builder.multi_modal_data(Some(media_map));
// Preserve original messages in extra_args for multimodal workers that need them
// (e.g., TRT-LLM multimodal processor needs raw messages for proper tokenization)
// Preserve original messages and formatted prompt in extra_args for multimodal
// workers (e.g., TRT-LLM needs messages and the template-rendered prompt with
// <image> placeholders for embedding-path / NIXL flows).
let messages_json = serde_json::to_value(request.messages())?;
let extra_args = serde_json::json!({
let mut extra_args = serde_json::json!({
"messages": messages_json
});
if let Some(ref prompt) = formatted_prompt {
extra_args["formatted_prompt"] = serde_json::Value::String(prompt.clone());
}
builder.extra_args(Some(extra_args));
}
......@@ -1335,7 +1340,8 @@ impl
};
// Gather multimodal data (works with both embeddings and text prompts)
self.gather_multi_modal_data(&request, &mut builder).await?;
self.gather_multi_modal_data(&request, &mut builder, None)
.await?;
let mut common_request = builder.build()?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment