Unverified Commit 51c4fe63 authored by ryan-lempka's avatar ryan-lempka Committed by GitHub
Browse files

fix: deserialize tool call args (#4176)


Signed-off-by: default avatarRyan Lempka <rlempka@nvidia.com>
parent 441473c3
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
use super::*; use super::*;
use minijinja::{context, value::Value}; use minijinja::{context, value::Value};
use std::result::Result::Ok;
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest, chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
...@@ -121,6 +122,36 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value { ...@@ -121,6 +122,36 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
Value::from_serialize(&updated_messages) Value::from_serialize(&updated_messages)
} }
fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) {
// Deserialize tool call arguments from JSON strings to objects/arrays before template rendering
// avoids double encoding and enables iteration
let Some(msgs) = messages.as_array_mut() else {
return;
};
for msg in msgs.iter_mut() {
if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
for tc in tool_calls {
if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut())
&& let Some(args) = function.get_mut("arguments")
&& let Some(s) = args.as_str()
&& let Ok(parsed) = serde_json::from_str(s)
{
*args = parsed;
}
}
}
if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut())
&& let Some(args) = function_call.get_mut("arguments")
&& let Some(s) = args.as_str()
&& let Ok(parsed) = serde_json::from_str(s)
{
*args = parsed;
}
}
}
impl OAIChatLikeRequest for NvCreateChatCompletionRequest { impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn model(&self) -> String { fn model(&self) -> String {
self.inner.model.clone() self.inner.model.clone()
...@@ -267,8 +298,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { ...@@ -267,8 +298,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
add_generation_prompt add_generation_prompt
); );
let messages_canonical = req.messages();
let mut messages_for_template: serde_json::Value =
serde_json::to_value(&messages_canonical).unwrap();
normalize_tool_arguments_in_messages(&mut messages_for_template);
let ctx = context! { let ctx = context! {
messages => req.messages(), messages => messages_for_template,
tools => tools, tools => tools,
bos_token => self.config.bos_tok(), bos_token => self.config.bos_tok(),
eos_token => self.config.eos_tok(), eos_token => self.config.eos_tok(),
...@@ -298,6 +334,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { ...@@ -298,6 +334,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
mod tests { mod tests {
use super::*; use super::*;
use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg; use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg;
use minijinja::{Environment, context};
#[test] #[test]
fn test_may_be_fix_tool_schema_missing_type_and_properties() { fn test_may_be_fix_tool_schema_missing_type_and_properties() {
...@@ -705,6 +742,153 @@ NORMAL MODE ...@@ -705,6 +742,153 @@ NORMAL MODE
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3); assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
} }
#[test]
fn test_normalize_tool_arguments_tojson() {
let tmpl = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;
// Message with tool_calls containing JSON string arguments
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
"role": "assistant",
"tool_calls": [{
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\"format\":\"celsius\",\"location\":\"San Francisco, CA\"}"
}
}]
})]);
normalize_tool_arguments_in_messages(&mut messages);
let mut env = Environment::new();
env.add_filter("tojson", super::super::tokcfg::tojson);
env.add_template("t", tmpl).unwrap();
let out = env
.get_template("t")
.unwrap()
.render(context! { messages => messages.as_array().unwrap() })
.unwrap();
// Should produce clean JSON without double-encoding
assert_eq!(
out,
r#"{"format":"celsius","location":"San Francisco, CA"}"#
);
}
#[test]
fn test_normalize_tool_arguments_items_loop() {
let tmpl = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
"role": "assistant",
"tool_calls": [{
"type": "function",
"function": {
"name": "f",
"arguments": "{\"a\":1,\"b\":\"x\"}"
}
}]
})]);
normalize_tool_arguments_in_messages(&mut messages);
let mut env = Environment::new();
env.add_template("t", tmpl).unwrap();
let out = env
.get_template("t")
.unwrap()
.render(context! { messages => messages.as_array().unwrap() })
.unwrap();
assert!(out == "a=1;b=x;" || out == "b=x;a=1;");
}
#[test]
fn test_normalize_tool_arguments_legacy_function_call() {
// Test deprecated function_call format (OpenAI compat)
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
"role": "assistant",
"function_call": {
"name": "get_weather",
"arguments": "{\"location\":\"NYC\"}"
}
})]);
normalize_tool_arguments_in_messages(&mut messages);
assert_eq!(
messages[0]["function_call"]["arguments"],
serde_json::json!({"location": "NYC"})
);
}
#[test]
fn test_normalize_tool_arguments_malformed_json_passthrough() {
// Malformed JSON should be left as a string
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
"role": "assistant",
"tool_calls": [{
"type": "function",
"function": {
"name": "f",
"arguments": "not valid json at all"
}
}]
})]);
normalize_tool_arguments_in_messages(&mut messages);
assert_eq!(
messages[0]["tool_calls"][0]["function"]["arguments"],
serde_json::Value::String("not valid json at all".to_string())
);
}
#[test]
fn test_normalize_tool_arguments_with_multimodal_content() {
let json_str = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Check this:"},
{"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
{"type": "text", "text": "Interesting?"}
]
},
{
"role": "assistant",
"tool_calls": [{
"id": "call_123",
"type": "function",
"function": {
"name": "analyze_video",
"arguments": "{\"url\":\"https://example.com/vid.mp4\",\"format\":\"mp4\"}"
}
}]
}
]
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let mut messages = serde_json::to_value(request.messages()).unwrap();
normalize_tool_arguments_in_messages(&mut messages);
// Multimodal content preserved as array
assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
// Tool arguments deserialized to object
assert!(messages[1]["tool_calls"][0]["function"]["arguments"].is_object());
assert_eq!(
messages[1]["tool_calls"][0]["function"]["arguments"]["url"],
"https://example.com/vid.mp4"
);
}
fn user() -> Msg { fn user() -> Msg {
Msg::User(Default::default()) Msg::User(Default::default())
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment