Unverified commit 6767559f, authored by KrishnanPrash and committed by GitHub
Browse files

fix: Support for msg[content] as a list (#4485)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
Signed-off-by: KrishnanPrash <140860868+KrishnanPrash@users.noreply.github.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 2f18b23e
...@@ -46,6 +46,8 @@ python -m dynamo.frontend --http-port=8000 & ...@@ -46,6 +46,8 @@ python -m dynamo.frontend --http-port=8000 &
EXTRA_ARGS="" EXTRA_ARGS=""
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048" EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
fi fi
# Start vLLM worker with vision model # Start vLLM worker with vision model
......
...@@ -106,6 +106,7 @@ struct HfTokenizerConfigJsonFormatter { ...@@ -106,6 +106,7 @@ struct HfTokenizerConfigJsonFormatter {
config: ChatTemplate, config: ChatTemplate,
mixins: Arc<ContextMixins>, mixins: Arc<ContextMixins>,
supports_add_generation_prompt: bool, supports_add_generation_prompt: bool,
requires_content_arrays: bool,
} }
// /// OpenAI Standard Prompt Formatter // /// OpenAI Standard Prompt Formatter
......
...@@ -6,9 +6,38 @@ use std::sync::Arc; ...@@ -6,9 +6,38 @@ use std::sync::Arc;
use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson}; use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson};
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment}; use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
use either::Either; use either::Either;
use minijinja::{Environment, Value}; use minijinja::{Environment, Value, context};
use serde_json::json;
use tracing; use tracing;
/// Detects if a template requires content as arrays (multimodal) vs strings (text-only).
/// Returns true if the template only works with array format.
fn detect_content_array_usage(env: &Environment) -> bool {
// Test with array format
let array_msg = context! {
messages => json!([{"role": "user", "content": [{"type": "text", "text": "template_test"}]}]),
add_generation_prompt => false,
};
// Test with string format
let string_msg = context! {
messages => json!([{"role": "user", "content": "template_test"}]),
add_generation_prompt => false,
};
let out_array = env
.get_template("default")
.and_then(|t| t.render(&array_msg))
.unwrap_or_default();
let out_string = env
.get_template("default")
.and_then(|t| t.render(&string_msg))
.unwrap_or_default();
// If array works but string doesn't, template requires arrays
out_array.contains("template_test") && !out_string.contains("template_test")
}
/// Remove known non-standard Jinja2 tags from chat templates /// Remove known non-standard Jinja2 tags from chat templates
/// ///
/// Some models use custom Jinja2 extensions that minijinja doesn't recognize. These tags /// Some models use custom Jinja2 extensions that minijinja doesn't recognize. These tags
...@@ -120,11 +149,15 @@ impl HfTokenizerConfigJsonFormatter { ...@@ -120,11 +149,15 @@ impl HfTokenizerConfigJsonFormatter {
} }
} }
// Detect at model load time whether this template requires content arrays
let requires_content_arrays = detect_content_array_usage(&env);
Ok(HfTokenizerConfigJsonFormatter { Ok(HfTokenizerConfigJsonFormatter {
env, env,
config, config,
mixins: Arc::new(mixins), mixins: Arc::new(mixins),
supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false), supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false),
requires_content_arrays,
}) })
} }
} }
......
...@@ -73,10 +73,9 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> { ...@@ -73,10 +73,9 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
Some(Value::from_serialize(&updated_tools)) Some(Value::from_serialize(&updated_tools))
} }
fn may_be_fix_msg_content(messages: serde_json::Value) -> Value { fn may_be_fix_msg_content(messages: serde_json::Value, preserve_arrays: bool) -> Value {
// If messages[content] is provided as a list containing ONLY text parts, // preserve_arrays=true: strings → arrays (multimodal)
// concatenate them into a string to match chat template expectations. // preserve_arrays=false: text-only arrays → strings (standard)
// Mixed content types are left for chat templates to handle.
let Some(arr) = messages.as_array() else { let Some(arr) = messages.as_array() else {
return Value::from_serialize(&messages); return Value::from_serialize(&messages);
...@@ -86,7 +85,20 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value { ...@@ -86,7 +85,20 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
.iter() .iter()
.map(|msg| { .map(|msg| {
match msg.get("content") { match msg.get("content") {
Some(serde_json::Value::Array(content_array)) => { // Case 1: String to Array (for multimodal templates)
Some(serde_json::Value::String(text)) if preserve_arrays => {
let mut modified_msg = msg.clone();
if let Some(msg_object) = modified_msg.as_object_mut() {
let content_array = serde_json::json!([{
"type": "text",
"text": text
}]);
msg_object.insert("content".to_string(), content_array);
}
modified_msg
}
// Case 2: Array to String (for standard templates)
Some(serde_json::Value::Array(content_array)) if !preserve_arrays => {
let is_text_only_array = !content_array.is_empty() let is_text_only_array = !content_array.is_empty()
&& content_array.iter().all(|part| { && content_array.iter().all(|part| {
part.get("type") part.get("type")
...@@ -114,7 +126,7 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value { ...@@ -114,7 +126,7 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
msg.clone() // Mixed content or non-text only msg.clone() // Mixed content or non-text only
} }
} }
_ => msg.clone(), // String content or missing content - return unchanged _ => msg.clone(), // No conversion needed
} }
}) })
.collect(); .collect();
...@@ -159,20 +171,8 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -159,20 +171,8 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn messages(&self) -> Value { fn messages(&self) -> Value {
let messages_json = serde_json::to_value(&self.inner.messages).unwrap(); let messages_json = serde_json::to_value(&self.inner.messages).unwrap();
let needs_fixing = if let Some(arr) = messages_json.as_array() {
arr.iter()
.any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
} else {
false
};
if needs_fixing {
may_be_fix_msg_content(messages_json)
} else {
Value::from_serialize(&messages_json) Value::from_serialize(&messages_json)
} }
}
fn tools(&self) -> Option<Value> { fn tools(&self) -> Option<Value> {
if self.inner.tools.is_none() { if self.inner.tools.is_none() {
...@@ -301,6 +301,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { ...@@ -301,6 +301,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
let messages_canonical = req.messages(); let messages_canonical = req.messages();
let mut messages_for_template: serde_json::Value = let mut messages_for_template: serde_json::Value =
serde_json::to_value(&messages_canonical).unwrap(); serde_json::to_value(&messages_canonical).unwrap();
messages_for_template = serde_json::to_value(may_be_fix_msg_content(
messages_for_template,
self.requires_content_arrays,
))
.unwrap();
normalize_tool_arguments_in_messages(&mut messages_for_template); normalize_tool_arguments_in_messages(&mut messages_for_template);
let ctx = context! { let ctx = context! {
...@@ -457,7 +464,10 @@ mod tests { ...@@ -457,7 +464,10 @@ mod tests {
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: text-only array is concatenated into a single string // Verify: text-only array is concatenated into a single string
assert_eq!( assert_eq!(
...@@ -500,7 +510,10 @@ mod tests { ...@@ -500,7 +510,10 @@ mod tests {
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: System message with string content remains unchanged // Verify: System message with string content remains unchanged
assert_eq!( assert_eq!(
...@@ -541,7 +554,10 @@ mod tests { ...@@ -541,7 +554,10 @@ mod tests {
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Empty arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: Empty arrays are preserved as-is // Verify: Empty arrays are preserved as-is
assert!(messages[0]["content"].is_array()); assert!(messages[0]["content"].is_array());
...@@ -562,7 +578,10 @@ mod tests { ...@@ -562,7 +578,10 @@ mod tests {
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Test with preserve_arrays=false (standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: String content is not modified // Verify: String content is not modified
assert_eq!( assert_eq!(
...@@ -589,7 +608,10 @@ mod tests { ...@@ -589,7 +608,10 @@ mod tests {
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Mixed content should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: Mixed content types are preserved as array for template handling // Verify: Mixed content types are preserved as array for template handling
assert!(messages[0]["content"].is_array()); assert!(messages[0]["content"].is_array());
...@@ -617,7 +639,10 @@ mod tests { ...@@ -617,7 +639,10 @@ mod tests {
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Non-text arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: Non-text content arrays are preserved for template handling // Verify: Non-text content arrays are preserved for template handling
assert!(messages[0]["content"].is_array()); assert!(messages[0]["content"].is_array());
...@@ -713,7 +738,8 @@ NORMAL MODE ...@@ -713,7 +738,8 @@ NORMAL MODE
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Mixed types should preserve array structure // Mixed types should preserve array structure
assert!(messages[0]["content"].is_array()); assert!(messages[0]["content"].is_array());
...@@ -735,7 +761,8 @@ NORMAL MODE ...@@ -735,7 +761,8 @@ NORMAL MODE
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Unknown types mixed with text should preserve array // Unknown types mixed with text should preserve array
assert!(messages[0]["content"].is_array()); assert!(messages[0]["content"].is_array());
...@@ -873,11 +900,15 @@ NORMAL MODE ...@@ -873,11 +900,15 @@ NORMAL MODE
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let mut messages = serde_json::to_value(request.messages()).unwrap(); let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Apply content normalization with preserve_arrays=false (standard templates)
let mut messages =
serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
normalize_tool_arguments_in_messages(&mut messages); normalize_tool_arguments_in_messages(&mut messages);
// Multimodal content preserved as array // Multimodal content preserved as array (mixed types not flattened)
assert!(messages[0]["content"].is_array()); assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3); assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
...@@ -889,6 +920,63 @@ NORMAL MODE ...@@ -889,6 +920,63 @@ NORMAL MODE
); );
} }
/// Verifies that plain string content is wrapped into a single-element text-part array
/// when `may_be_fix_msg_content` is called with `preserve_arrays = true`.
#[test]
fn test_may_be_fix_msg_content_string_to_array() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let raw_messages = serde_json::to_value(parsed.messages()).unwrap();

    // preserve_arrays=true selects the string → array direction (multimodal templates).
    let normalized = serde_json::to_value(may_be_fix_msg_content(raw_messages, true)).unwrap();

    // The string must now be an array holding exactly one text part.
    let content = &normalized[0]["content"];
    assert!(content.is_array());
    let parts = content.as_array().unwrap();
    assert_eq!(parts.len(), 1);

    let only_part = &parts[0];
    assert_eq!(only_part["type"], "text");
    assert_eq!(only_part["text"], "Hello, how are you?");
}
/// Verifies that array content passes through unchanged when `may_be_fix_msg_content`
/// is called with `preserve_arrays = true`.
#[test]
fn test_may_be_fix_msg_content_array_preserved_with_multimodal() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "part 1"},
{"type": "text", "text": "part 2"}
]
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let raw_messages = serde_json::to_value(parsed.messages()).unwrap();

    // preserve_arrays=true (multimodal templates): text-only arrays must not be flattened.
    let normalized = serde_json::to_value(may_be_fix_msg_content(raw_messages, true)).unwrap();

    let content = &normalized[0]["content"];
    assert!(content.is_array());
    let parts = content.as_array().unwrap();
    assert_eq!(parts.len(), 2);

    // Both parts keep their original text, in order.
    for (part, expected_text) in parts.iter().zip(["part 1", "part 2"]) {
        assert_eq!(part["text"], expected_text);
    }
}
fn user() -> Msg { fn user() -> Msg {
Msg::User(Default::default()) Msg::User(Default::default())
} }
......
...@@ -230,6 +230,42 @@ vllm_configs = { ...@@ -230,6 +230,42 @@ vllm_configs = {
), ),
], ],
), ),
"multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava",
directory=vllm_dir,
script_name="agg_multimodal.sh",
marks=[
pytest.mark.gpu_2,
# https://github.com/ai-dynamo/dynamo/issues/4501
pytest.mark.xfail(strict=False),
],
model="llava-hf/llava-1.5-7b-hf",
script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
delayed_start=0,
timeout=360,
request_payloads=[
# HTTP URL test
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
temperature=0.0,
),
# String content test - verifies string → array conversion for multimodal templates
chat_payload_default(
repeat_count=1,
expected_response=[], # Just validate no error
),
],
),
# TODO: Update this test case when we have video multimodal support in vllm official components # TODO: Update this test case when we have video multimodal support in vllm official components
"multimodal_video_agg": VLLMConfig( "multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg", name="multimodal_video_agg",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment