Unverified Commit 59f474a2 authored by davilu-nvidia's avatar davilu-nvidia Committed by GitHub
Browse files

feat(async-openai): accept dict format for tool-call arguments (#7772)


Signed-off-by: default avatardavilu-nvidia <davilu@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 3efc733f
......@@ -29,7 +29,6 @@ pub use async_openai::types::chat::{
ChatCompletionFunctionCall,
ChatCompletionFunctions,
ChatCompletionFunctionsArgs,
ChatCompletionMessageToolCallChunk,
ChatCompletionRequestAssistantMessageAudio,
ChatCompletionRequestAssistantMessageContent,
ChatCompletionRequestAssistantMessageContentPart,
......@@ -56,8 +55,6 @@ pub use async_openai::types::chat::{
CompletionFinishReason,
CompletionTokensDetails,
CompletionUsage,
FunctionCall,
FunctionCallStream,
FunctionObject,
FunctionObjectArgs,
InputAudio,
......@@ -90,6 +87,90 @@ pub use async_openai::types::chat::FinishReason;
// Re-export both names for compatibility.
pub use async_openai::types::chat::FunctionType;
// ---------------------------------------------------------------------------
// Flexible `arguments` deserialisation helpers
// ---------------------------------------------------------------------------
// Some agent frameworks (e.g. LangChain, custom harnesses) send tool-call
// arguments as a pre-parsed JSON object instead of the canonical JSON
// string. The helpers below normalise both representations to a `String` so
// downstream code never needs to branch on the wire format.
fn deserialize_arguments<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let value = serde_json::Value::deserialize(deserializer)?;
match value {
serde_json::Value::String(s) => Ok(s),
v @ serde_json::Value::Object(_) => {
// serde_json::to_string on a Value is infallible
Ok(serde_json::to_string(&v).unwrap())
}
other => Err(D::Error::custom(format!(
"expected string or object for `arguments`, got {other}"
))),
}
}
fn deserialize_arguments_opt<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let value = Option::<serde_json::Value>::deserialize(deserializer)?;
match value {
None => Ok(None),
Some(serde_json::Value::String(s)) => Ok(Some(s)),
Some(v @ serde_json::Value::Object(_)) => serde_json::to_string(&v)
.map(Some)
.map_err(|e| D::Error::custom(e.to_string())),
Some(other) => Err(D::Error::custom(format!(
"expected string or object for `arguments`, got {other}"
))),
}
}
// ---------------------------------------------------------------------------
// FunctionCall / FunctionCallStream — local definitions with flexible deser
// ---------------------------------------------------------------------------
// Upstream `async-openai` only accepts a JSON string for `arguments`.
// We define these locally so we can attach `#[serde(deserialize_with)]` and
// accept both string and object representations on the wire.
/// The name and arguments of a function that should be called.
///
/// Accepts `arguments` as either a JSON string (`"{\"key\":\"value\"}"`) or a
/// JSON object (`{"key": "value"}`); both are normalised to a JSON string
/// on deserialisation so callers always see the canonical form.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub struct FunctionCall {
pub name: String,
#[serde(deserialize_with = "deserialize_arguments")]
pub arguments: String,
}
/// Streaming variant of [`FunctionCall`] where both fields are optional.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub struct FunctionCallStream {
pub name: Option<String>,
#[serde(default, deserialize_with = "deserialize_arguments_opt")]
pub arguments: Option<String>,
}
/// Streaming tool-call chunk.
///
/// Defined locally (instead of re-exporting from upstream) because its
/// `function` field references our local [`FunctionCallStream`] with the
/// flexible `arguments` deserialiser.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub struct ChatCompletionMessageToolCallChunk {
pub index: u32,
pub id: Option<String>,
pub r#type: Option<FunctionType>,
pub function: Option<FunctionCallStream>,
}
// ---------------------------------------------------------------------------
// Types with structural differences from upstream (kept locally)
// ---------------------------------------------------------------------------
......@@ -658,6 +739,7 @@ pub struct ChatCompletionStreamResponseDelta {
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponseDeltaFunctionCall {
pub name: Option<String>,
#[serde(default, deserialize_with = "deserialize_arguments_opt")]
pub arguments: Option<String>,
}
......@@ -722,4 +804,139 @@ mod tests {
let json = serde_json::to_value(tool_call).unwrap();
assert_eq!(json["type"], "function");
}
// -- dict-format arguments tests --
#[test]
fn function_call_accepts_string_arguments() {
let fc: FunctionCall = serde_json::from_value(serde_json::json!({
"name": "get_weather",
"arguments": "{\"location\":\"SF\"}"
}))
.unwrap();
assert_eq!(fc.arguments, "{\"location\":\"SF\"}");
}
#[test]
fn function_call_accepts_dict_arguments() {
let fc: FunctionCall = serde_json::from_value(serde_json::json!({
"name": "get_weather",
"arguments": {"location": "SF"}
}))
.unwrap();
assert_eq!(fc.arguments, "{\"location\":\"SF\"}");
}
#[test]
fn function_call_rejects_integer_arguments() {
let result = serde_json::from_value::<FunctionCall>(serde_json::json!({
"name": "f",
"arguments": 42
}));
assert!(result.is_err());
}
#[test]
fn function_call_rejects_boolean_arguments() {
let result = serde_json::from_value::<FunctionCall>(serde_json::json!({
"name": "f",
"arguments": true
}));
assert!(result.is_err());
}
#[test]
fn function_call_rejects_null_arguments() {
let result = serde_json::from_value::<FunctionCall>(serde_json::json!({
"name": "f",
"arguments": null
}));
assert!(result.is_err());
}
#[test]
fn function_call_rejects_array_arguments() {
let result = serde_json::from_value::<FunctionCall>(serde_json::json!({
"name": "f",
"arguments": [1, 2, 3]
}));
assert!(result.is_err());
}
#[test]
fn function_call_stream_null_arguments_produces_none() {
let fcs: FunctionCallStream = serde_json::from_value(serde_json::json!({
"name": "f",
"arguments": null
}))
.unwrap();
assert_eq!(fcs.arguments, None);
}
#[test]
fn function_call_stream_rejects_integer_arguments() {
let result = serde_json::from_value::<FunctionCallStream>(serde_json::json!({
"name": "f",
"arguments": 42
}));
assert!(result.is_err());
}
#[test]
fn function_call_stream_rejects_boolean_arguments() {
let result = serde_json::from_value::<FunctionCallStream>(serde_json::json!({
"name": "f",
"arguments": true
}));
assert!(result.is_err());
}
#[test]
fn function_call_stream_accepts_dict_arguments() {
let fcs: FunctionCallStream = serde_json::from_value(serde_json::json!({
"name": "get_weather",
"arguments": {"location": "SF"}
}))
.unwrap();
assert_eq!(fcs.arguments.as_deref(), Some("{\"location\":\"SF\"}"));
}
#[test]
fn function_call_stream_accepts_null_arguments() {
let fcs: FunctionCallStream = serde_json::from_value(serde_json::json!({
"name": "get_weather"
}))
.unwrap();
assert_eq!(fcs.arguments, None);
}
#[test]
fn tool_call_with_dict_arguments_roundtrip() {
let tc: ChatCompletionMessageToolCall = serde_json::from_value(serde_json::json!({
"id": "call_abc",
"type": "function",
"function": {
"name": "search",
"arguments": {"query": "hello", "limit": 10}
}
}))
.unwrap();
// Compare as parsed JSON values since key order is non-deterministic
let parsed: serde_json::Value = serde_json::from_str(&tc.function.arguments).unwrap();
assert_eq!(parsed, serde_json::json!({"query": "hello", "limit": 10}));
// Re-serialisation produces a string, not an object
let json = serde_json::to_value(&tc).unwrap();
assert!(json["function"]["arguments"].is_string());
}
#[test]
fn stream_delta_function_call_accepts_dict_arguments() {
let delta: ChatCompletionStreamResponseDeltaFunctionCall =
serde_json::from_value(serde_json::json!({
"name": "get_weather",
"arguments": {"location": "SF"}
}))
.unwrap();
assert_eq!(delta.arguments.as_deref(), Some("{\"location\":\"SF\"}"));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment