Unverified Commit a79122c6 authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: Media URL passthrough in OAI preprocessor (#3733)


Signed-off-by: default avatarAlexandre Milesi <30204471+milesial@users.noreply.github.com>
parent a34f52cf
......@@ -13,9 +13,12 @@
pub mod prompt;
pub mod tools;
use anyhow::Context;
use anyhow::{Result, bail};
use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, EncodingFormat};
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionToolChoiceOption, EncodingFormat,
};
use futures::Stream;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
......@@ -24,7 +27,10 @@ use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::protocols::common::preprocessor::{
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
};
use crate::tokenizers::Encoding;
use dynamo_parsers::{ReasoningParser, ReasoningParserType};
......@@ -168,8 +174,14 @@ impl OpenAIPreprocessor {
request: &R,
) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut builder = self.builder(request)?;
let formatted_prompt = self.apply_template(request)?;
let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;
let formatted_prompt = self
.apply_template(request)
.with_context(|| "Failed to apply prompt template")?;
let annotations = self
.gather_tokens(request, &mut builder, formatted_prompt)
.with_context(|| "Failed to gather tokens")?;
self.gather_multi_modal_data(request, &mut builder)
.with_context(|| "Failed to gather multimodal data")?;
Ok((builder.build()?, annotations))
}
......@@ -255,6 +267,57 @@ impl OpenAIPreprocessor {
}
}
pub fn gather_multi_modal_data<R: OAIChatLikeRequest>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
) -> Result<()> {
let messages = request.messages();
let message_count = messages.len().unwrap_or(0);
let mut media_map: MultimodalDataMap = HashMap::new();
for idx in 0..message_count {
let msg = messages
.get_item_by_index(idx)
.map_err(|_| anyhow::Error::msg(format!("Cannot get message at index {idx}")))?;
let msg_json: serde_json::Value = serde_json::to_value(&msg)?;
let message: ChatCompletionRequestMessage = serde_json::from_value(msg_json)?;
let content_parts = match &message {
ChatCompletionRequestMessage::User(u) => match &u.content {
ChatCompletionRequestUserMessageContent::Array(parts) => parts,
_ => continue,
},
_ => continue,
};
// Iterate over content parts
for content_part in content_parts {
let (type_str, url) = match content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
("image_url".to_string(), image_part.image_url.url.clone())
}
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
("video_url".to_string(), video_part.video_url.url.clone())
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(audio_part) => {
("audio_url".to_string(), audio_part.audio_url.url.clone())
}
_ => continue,
};
let map_item = media_map.entry(type_str.clone()).or_default();
map_item.push(MultimodalData::Url(url));
}
}
if !media_map.is_empty() {
builder.multi_modal_data(Some(media_map));
}
Ok(())
}
pub fn gather_tokens<
R: OAIChatLikeRequest
+ AnnotationsProvider
......@@ -789,7 +852,6 @@ impl
// forward the common completion request to the next operator
let response_stream = next.generate(common_request).await?;
// Extract context once
let context = response_stream.context();
......@@ -898,6 +960,8 @@ impl
// convert the chat completion request to a common completion request
let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
self.gather_multi_modal_data(&request, &mut builder)?;
let common_request = builder.build()?;
// update isl
......
......@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
use crate::protocols::TokenIdType;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData {
Url(url::Url),
// TODO: Decoded(DecodedMediaData),
}
// multimodal map containing {mm_part_type: [data...]}
pub type MultimodalDataMap = std::collections::HashMap<String, Vec<MultimodalData>>;
/// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
/// crate is responsible for converting request from the public APIs to this internal representation.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
......@@ -18,6 +27,10 @@ pub struct PreprocessedRequest {
/// Type of prompt
pub token_ids: Vec<TokenIdType>,
// Multimodal data
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub multi_modal_data: Option<MultimodalDataMap>,
/// StopConditions are conditions that the inference engine will use to stop generation.
pub stop_conditions: StopConditions,
......
......@@ -9,6 +9,7 @@ use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionReque
use serde::{Deserialize, Serialize};
use hf_hub::{Cache, Repo, RepoType, api::tokio::ApiBuilder};
use rstest::rstest;
use std::path::PathBuf;
......@@ -492,3 +493,97 @@ async fn test_multi_turn_with_continuation() {
insta::assert_snapshot!(formatted_prompt);
});
}
// Helper to build message with media chunks (single or mixed types)
fn build_message(text: &str, chunks: &[(&str, usize)]) -> String {
let mut content_parts = vec![format!(r#"{{"type": "text", "text": "{}"}}"#, text)];
for (chunk_type, count) in chunks {
for i in 1..=*count {
let chunk = match *chunk_type {
"image_url" => format!(
r#"{{"type": "image_url", "image_url": {{"url": "https://example.com/img{}.jpg"}}}}"#,
i
),
"video_url" => format!(
r#"{{"type": "video_url", "video_url": {{"url": "https://example.com/vid{}.mp4"}}}}"#,
i
),
"audio_url" => format!(
r#"{{"type": "audio_url", "audio_url": {{"url": "https://example.com/audio{}.mp3"}}}}"#,
i
),
_ => panic!("Unknown chunk type: {}", chunk_type),
};
content_parts.push(chunk);
}
}
format!(
r#"[{{"role": "user", "content": [{}]}}]"#,
content_parts.join(", ")
)
}
/// Test the preprocessor with multimodal data (single and mixed types) to verify gather_multi_modal_data code path
#[rstest]
// No media case
#[case::no_media(&[])]
// Single media item cases
#[case::single_video(&[("video_url", 1)])]
// Multiple media items of the same type
#[case::three_images(&[("image_url", 3)])]
// Mixed media types
#[case::mixed_multiple(&[("image_url", 2), ("video_url", 1), ("audio_url", 2)])]
#[tokio::test]
async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let preprocessor = dynamo_llm::preprocessor::OpenAIPreprocessor::new(mdc.clone()).unwrap();
// Build the message with the specified media chunks
let message = build_message("Test multimodal content", media_chunks);
let request = Request::from(&message, None, None, mdc.slug().to_string());
let (preprocessed, _annotations) = preprocessor.preprocess_request(&request).unwrap();
// Verify multimodal data handling
if media_chunks.is_empty() {
// No media case - should be None or empty
assert!(
preprocessed.multi_modal_data.is_none()
|| preprocessed.multi_modal_data.as_ref().unwrap().is_empty(),
"Multimodal data should be None or empty when no media is present"
);
} else {
// Media present - should be captured
assert!(
preprocessed.multi_modal_data.is_some(),
"Multimodal data should be present"
);
let media_map = preprocessed.multi_modal_data.as_ref().unwrap();
// Check each media type and count
for (media_type, expected_count) in media_chunks {
assert!(
media_map.contains_key(*media_type),
"Should contain {} key",
media_type
);
assert_eq!(
media_map.get(*media_type).unwrap().len(),
*expected_count,
"Should have {} {} item(s)",
expected_count,
media_type
);
}
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment