Unverified commit c319aa2d, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: added tool call schema validation in oai formatter (#2935)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent be48b4cf
...@@ -24,6 +24,66 @@ use tracing; ...@@ -24,6 +24,66 @@ use tracing;
use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput}; use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};
/// Fills in the minimal schema fields on each tool's `function.parameters`.
///
/// OpenAI accepts an empty `parameters` object at the request level, so the
/// minimal shape has to be enforced here, before the tools reach the template.
/// For every element of `tools` whose `function.parameters` is a JSON object,
/// `"type": "object"` and `"properties": {}` are inserted when the respective
/// key is absent; keys that are already present are left untouched. Tools
/// without a `function` / `parameters` entry, or whose `parameters` is not an
/// object, pass through unchanged.
///
/// No other schema checks are performed: the basic named-function schema is
/// already validated while the request is created.
///
/// Always returns `Some`. If `tools` is not a JSON array, the wrapped value is
/// built from an empty list.
fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
    // `tools` is owned, so take the array out of it and patch the entries in
    // place instead of deep-cloning every tool into a second vector.
    let mut tools = match tools {
        serde_json::Value::Array(items) => items,
        _ => Vec::new(),
    };

    for tool in &mut tools {
        if let Some(function) = tool.get_mut("function")
            && let Some(parameters) = function.get_mut("parameters")
            && let Some(schema) = parameters.as_object_mut()
        {
            // `entry(..).or_insert_with(..)` only writes when the key is
            // missing, which covers both the empty-object case and the
            // partially-specified case.
            schema
                .entry("type")
                .or_insert_with(|| serde_json::Value::String("object".to_string()));
            schema
                .entry("properties")
                .or_insert_with(|| serde_json::Value::Object(Default::default()));
        }
    }

    Some(Value::from_serialize(&tools))
}
impl OAIChatLikeRequest for NvCreateChatCompletionRequest { impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn model(&self) -> String { fn model(&self) -> String {
self.inner.model.clone() self.inner.model.clone()
...@@ -37,7 +97,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -37,7 +97,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
if self.inner.tools.is_none() { if self.inner.tools.is_none() {
None None
} else { } else {
Some(Value::from_serialize(&self.inner.tools)) // Try to fix the tool schema if it is missing type and properties
Some(may_be_fix_tool_schema(
serde_json::to_value(&self.inner.tools).unwrap(),
)?)
} }
} }
...@@ -160,12 +223,120 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { ...@@ -160,12 +223,120 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
}}; }};
let tmpl = if has_tools { let tmpl: minijinja::Template<'_, '_> = if has_tools {
self.env.get_template("tool_use")? self.env.get_template("tool_use")?
} else { } else {
self.env.get_template("default")? self.env.get_template("default")?
}; };
Ok(tmpl.render(&ctx)?) Ok(tmpl.render(&ctx)?)
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chat completion request with a single `get_weather` tool whose
    /// `parameters` field is the given JSON value, and returns the tools as
    /// produced by `request.tools()`, converted back to `serde_json::Value`.
    ///
    /// The request is round-tripped through a JSON string so that it is
    /// deserialized with `serde_json::from_str`, the same entry point the
    /// original fixtures used.
    fn tools_with_parameters(parameters: serde_json::Value) -> serde_json::Value {
        let body = serde_json::json!({
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": parameters,
                        "strict": null
                    }
                }
            ]
        })
        .to_string();

        let request: NvCreateChatCompletionRequest = serde_json::from_str(&body).unwrap();
        serde_json::to_value(request.tools()).unwrap()
    }

    /// An empty `parameters` object gains both `type` and `properties`.
    #[test]
    fn test_may_be_fix_tool_schema_missing_type_and_properties() {
        let tools = tools_with_parameters(serde_json::json!({}));
        let parameters = &tools[0]["function"]["parameters"];

        // `assert_eq!` (rather than `assert!(a == b)`) prints both sides on failure.
        assert_eq!(parameters["type"], "object");
        assert_eq!(parameters["properties"], serde_json::json!({}));
    }

    /// A missing `type` is filled in while existing `properties` are preserved.
    #[test]
    fn test_may_be_fix_tool_schema_missing_type() {
        let tools = tools_with_parameters(serde_json::json!({
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City and state, e.g., 'San Francisco, CA'"
                }
            }
        }));
        let parameters = &tools[0]["function"]["parameters"];

        assert_eq!(parameters["type"], "object");
        assert_eq!(
            parameters["properties"],
            serde_json::json!({
                "location": {
                    "type": "string",
                    "description": "City and state, e.g., 'San Francisco, CA'"
                }
            })
        );
    }

    /// A missing `properties` is filled in while the existing `type` is preserved.
    #[test]
    fn test_may_be_fix_tool_schema_missing_properties() {
        let tools = tools_with_parameters(serde_json::json!({ "type": "object" }));
        let parameters = &tools[0]["function"]["parameters"];

        assert_eq!(parameters["properties"], serde_json::json!({}));
        assert_eq!(parameters["type"], "object");
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment