Unverified commit 6767559f, authored by KrishnanPrash and committed by GitHub
Browse files

fix: Support for msg[content] as a list (#4485)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
Signed-off-by: KrishnanPrash <140860868+KrishnanPrash@users.noreply.github.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 2f18b23e
......@@ -46,6 +46,8 @@ python -m dynamo.frontend --http-port=8000 &
# Extra command-line flags selected by model name. The two models listed
# below share the same settings; any other MODEL_NAME leaves EXTRA_ARGS empty.
EXTRA_ARGS=""
case "$MODEL_NAME" in
    "Qwen/Qwen2.5-VL-7B-Instruct" | "llava-hf/llava-1.5-7b-hf")
        EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
        ;;
esac
# Start vLLM worker with vision model
......
......@@ -106,6 +106,7 @@ struct HfTokenizerConfigJsonFormatter {
config: ChatTemplate,
mixins: Arc<ContextMixins>,
supports_add_generation_prompt: bool,
requires_content_arrays: bool,
}
// /// OpenAI Standard Prompt Formatter
......
......@@ -6,9 +6,38 @@ use std::sync::Arc;
use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson};
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
use either::Either;
use minijinja::{Environment, Value};
use minijinja::{Environment, Value, context};
use serde_json::json;
use tracing;
/// Detects if a template requires content as arrays (multimodal) vs strings (text-only).
/// Returns true if the template only works with array format.
fn detect_content_array_usage(env: &Environment) -> bool {
// Test with array format
let array_msg = context! {
messages => json!([{"role": "user", "content": [{"type": "text", "text": "template_test"}]}]),
add_generation_prompt => false,
};
// Test with string format
let string_msg = context! {
messages => json!([{"role": "user", "content": "template_test"}]),
add_generation_prompt => false,
};
let out_array = env
.get_template("default")
.and_then(|t| t.render(&array_msg))
.unwrap_or_default();
let out_string = env
.get_template("default")
.and_then(|t| t.render(&string_msg))
.unwrap_or_default();
// If array works but string doesn't, template requires arrays
out_array.contains("template_test") && !out_string.contains("template_test")
}
/// Remove known non-standard Jinja2 tags from chat templates
///
/// Some models use custom Jinja2 extensions that minijinja doesn't recognize. These tags
......@@ -120,11 +149,15 @@ impl HfTokenizerConfigJsonFormatter {
}
}
// Detect at model load time whether this template requires content arrays
let requires_content_arrays = detect_content_array_usage(&env);
Ok(HfTokenizerConfigJsonFormatter {
env,
config,
mixins: Arc::new(mixins),
supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false),
requires_content_arrays,
})
}
}
......
......@@ -73,10 +73,9 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
Some(Value::from_serialize(&updated_tools))
}
fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
// If messages[content] is provided as a list containing ONLY text parts,
// concatenate them into a string to match chat template expectations.
// Mixed content types are left for chat templates to handle.
fn may_be_fix_msg_content(messages: serde_json::Value, preserve_arrays: bool) -> Value {
// preserve_arrays=true: strings → arrays (multimodal)
// preserve_arrays=false: text-only arrays → strings (standard)
let Some(arr) = messages.as_array() else {
return Value::from_serialize(&messages);
......@@ -86,7 +85,20 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
.iter()
.map(|msg| {
match msg.get("content") {
Some(serde_json::Value::Array(content_array)) => {
// Case 1: String to Array (for multimodal templates)
Some(serde_json::Value::String(text)) if preserve_arrays => {
let mut modified_msg = msg.clone();
if let Some(msg_object) = modified_msg.as_object_mut() {
let content_array = serde_json::json!([{
"type": "text",
"text": text
}]);
msg_object.insert("content".to_string(), content_array);
}
modified_msg
}
// Case 2: Array to String (for standard templates)
Some(serde_json::Value::Array(content_array)) if !preserve_arrays => {
let is_text_only_array = !content_array.is_empty()
&& content_array.iter().all(|part| {
part.get("type")
......@@ -114,7 +126,7 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
msg.clone() // Mixed content or non-text only
}
}
_ => msg.clone(), // String content or missing content - return unchanged
_ => msg.clone(), // No conversion needed
}
})
.collect();
......@@ -159,20 +171,8 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn messages(&self) -> Value {
let messages_json = serde_json::to_value(&self.inner.messages).unwrap();
let needs_fixing = if let Some(arr) = messages_json.as_array() {
arr.iter()
.any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
} else {
false
};
if needs_fixing {
may_be_fix_msg_content(messages_json)
} else {
Value::from_serialize(&messages_json)
}
}
fn tools(&self) -> Option<Value> {
if self.inner.tools.is_none() {
......@@ -301,6 +301,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
let messages_canonical = req.messages();
let mut messages_for_template: serde_json::Value =
serde_json::to_value(&messages_canonical).unwrap();
messages_for_template = serde_json::to_value(may_be_fix_msg_content(
messages_for_template,
self.requires_content_arrays,
))
.unwrap();
normalize_tool_arguments_in_messages(&mut messages_for_template);
let ctx = context! {
......@@ -457,7 +464,10 @@ mod tests {
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: text-only array is concatenated into a single string
assert_eq!(
......@@ -500,7 +510,10 @@ mod tests {
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Test array → string normalization (preserve_arrays=false for standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: System message with string content remains unchanged
assert_eq!(
......@@ -541,7 +554,10 @@ mod tests {
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Empty arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: Empty arrays are preserved as-is
assert!(messages[0]["content"].is_array());
......@@ -562,7 +578,10 @@ mod tests {
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Test with preserve_arrays=false (standard templates)
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: String content is not modified
assert_eq!(
......@@ -589,7 +608,10 @@ mod tests {
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Mixed content should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: Mixed content types are preserved as array for template handling
assert!(messages[0]["content"].is_array());
......@@ -617,7 +639,10 @@ mod tests {
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Non-text arrays should be preserved regardless of preserve_arrays setting
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Verify: Non-text content arrays are preserved for template handling
assert!(messages[0]["content"].is_array());
......@@ -713,7 +738,8 @@ NORMAL MODE
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Mixed types should preserve array structure
assert!(messages[0]["content"].is_array());
......@@ -735,7 +761,8 @@ NORMAL MODE
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
// Unknown types mixed with text should preserve array
assert!(messages[0]["content"].is_array());
......@@ -873,11 +900,15 @@ NORMAL MODE
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let mut messages = serde_json::to_value(request.messages()).unwrap();
let messages_raw = serde_json::to_value(request.messages()).unwrap();
// Apply content normalization with preserve_arrays=false (standard templates)
let mut messages =
serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
normalize_tool_arguments_in_messages(&mut messages);
// Multimodal content preserved as array
// Multimodal content preserved as array (mixed types not flattened)
assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
......@@ -889,6 +920,63 @@ NORMAL MODE
);
}
/// With `preserve_arrays = true`, a plain-string `content` must come back as a
/// one-element array holding a single `{"type": "text", "text": ...}` part.
#[test]
fn test_may_be_fix_msg_content_string_to_array() {
    let json_str = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
    let raw_messages = serde_json::to_value(parsed.messages()).unwrap();

    // preserve_arrays=true selects the string → array direction.
    let normalized = serde_json::to_value(may_be_fix_msg_content(raw_messages, true)).unwrap();

    // The content is now an array with exactly one text part.
    let content = &normalized[0]["content"];
    assert!(content.is_array());
    let parts = content.as_array().unwrap();
    assert_eq!(parts.len(), 1);

    // The single part carries the original text unchanged.
    let only_part = &parts[0];
    assert_eq!(only_part["type"], "text");
    assert_eq!(only_part["text"], "Hello, how are you?");
}
/// With `preserve_arrays = true`, a text-only `content` array must pass
/// through untouched: same length, same parts, same order.
#[test]
fn test_may_be_fix_msg_content_array_preserved_with_multimodal() {
    let json_str = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "part 1"},
{"type": "text", "text": "part 2"}
]
}
]
}"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
    let raw_messages = serde_json::to_value(parsed.messages()).unwrap();

    // preserve_arrays=true: arrays must not be flattened into a string.
    let normalized = serde_json::to_value(may_be_fix_msg_content(raw_messages, true)).unwrap();

    let content = &normalized[0]["content"];
    assert!(content.is_array());
    let parts = content.as_array().unwrap();
    assert_eq!(parts.len(), 2);

    // Each part keeps its text, in the original order.
    for (part, expected_text) in parts.iter().zip(["part 1", "part 2"]) {
        assert_eq!(part["text"], expected_text);
    }
}
/// Test helper: builds a `Msg::User` wrapping the default value of its payload type.
fn user() -> Msg {
    let payload = Default::default();
    Msg::User(payload)
}
......
......@@ -230,6 +230,42 @@ vllm_configs = {
),
],
),
"multimodal_agg_llava": VLLMConfig(
name="multimodal_agg_llava",
directory=vllm_dir,
script_name="agg_multimodal.sh",
marks=[
pytest.mark.gpu_2,
# https://github.com/ai-dynamo/dynamo/issues/4501
pytest.mark.xfail(strict=False),
],
model="llava-hf/llava-1.5-7b-hf",
script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
delayed_start=0,
timeout=360,
request_payloads=[
# HTTP URL test
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
temperature=0.0,
),
# String content test - verifies string → array conversion for multimodal templates
chat_payload_default(
repeat_count=1,
expected_response=[], # Just validate no error
),
],
),
# TODO: Update this test case when we have video multimodal support in vllm official components
"multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment