Unverified Commit 5e2f29f5 authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

feat: Support for unary tool use in ChatCompletions API (#1800)

parent 6835dd7a
...@@ -291,6 +291,14 @@ async fn chat_completions( ...@@ -291,6 +291,14 @@ async fn chat_completions(
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
// Handle unsupported fields - if Some(resp) is returned by
// validate_chat_completion_unsupported_fields,
// then a field was used that is unsupported. We will log an error message
// and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed.
if let Some(resp) = validate_chat_completion_unsupported_fields(&request) {
return Ok(resp.into_response());
}
// Apply template values if present // Apply template values if present
if let Some(template) = template { if let Some(template) = template {
if request.inner.model.is_empty() { if request.inner.model.is_empty() {
...@@ -393,6 +401,41 @@ async fn chat_completions( ...@@ -393,6 +401,41 @@ async fn chat_completions(
} }
} }
/// Checks for unsupported fields in the request.
/// Returns Some(response) if unsupported fields are present.
#[allow(deprecated)]
pub fn validate_chat_completion_unsupported_fields(
request: &NvCreateChatCompletionRequest,
) -> Option<impl IntoResponse> {
let inner = &request.inner;
if inner.parallel_tool_calls == Some(true) {
return Some(ErrorResponse::not_implemented_error(
"`parallel_tool_calls: true` is not supported.",
));
}
if inner.stream == Some(true) && inner.tools.is_some() {
return Some(ErrorResponse::not_implemented_error(
"`stream: true` is not supported when `tools` are provided.",
));
}
if inner.function_call.is_some() {
return Some(ErrorResponse::not_implemented_error(
"`function_call` is deprecated. Please migrate to use `tool_choice` instead.",
));
}
if inner.functions.is_some() {
return Some(ErrorResponse::not_implemented_error(
"`functions` is deprecated. Please migrate to use `tools` instead.",
));
}
None
}
/// OpenAI Responses Request Handler /// OpenAI Responses Request Handler
/// ///
/// This method will handle the incoming request for the /v1/responses endpoint. /// This method will handle the incoming request for the /v1/responses endpoint.
...@@ -407,7 +450,7 @@ async fn responses( ...@@ -407,7 +450,7 @@ async fn responses(
// Handle unsupported fields - if Some(resp) is returned by validate_unsupported_fields, // Handle unsupported fields - if Some(resp) is returned by validate_unsupported_fields,
// then a field was used that is unsupported. We will log an error message // then a field was used that is unsupported. We will log an error message
// and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed. // and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed.
if let Some(resp) = validate_unsupported_fields(&request) { if let Some(resp) = validate_response_unsupported_fields(&request) {
return Ok(resp.into_response()); return Ok(resp.into_response());
} }
...@@ -415,7 +458,7 @@ async fn responses( ...@@ -415,7 +458,7 @@ async fn responses(
// validate_input_is_text_only, then we are handling something other than Input::Text(_). // validate_input_is_text_only, then we are handling something other than Input::Text(_).
// We will log an error message and early return a 501 NOT_IMPLEMENTED status code. // We will log an error message and early return a 501 NOT_IMPLEMENTED status code.
// Otherwise, proceeed. // Otherwise, proceeed.
if let Some(resp) = validate_input_is_text_only(&request) { if let Some(resp) = validate_response_input_is_text_only(&request) {
return Ok(resp.into_response()); return Ok(resp.into_response());
} }
...@@ -504,7 +547,9 @@ async fn responses( ...@@ -504,7 +547,9 @@ async fn responses(
Ok(Json(response).into_response()) Ok(Json(response).into_response())
} }
pub fn validate_input_is_text_only(request: &NvCreateResponse) -> Option<impl IntoResponse> { pub fn validate_response_input_is_text_only(
request: &NvCreateResponse,
) -> Option<impl IntoResponse> {
match &request.inner.input { match &request.inner.input {
async_openai::types::responses::Input::Text(_) => None, async_openai::types::responses::Input::Text(_) => None,
_ => Some(ErrorResponse::not_implemented_error("Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.")), _ => Some(ErrorResponse::not_implemented_error("Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.")),
...@@ -513,7 +558,9 @@ pub fn validate_input_is_text_only(request: &NvCreateResponse) -> Option<impl In ...@@ -513,7 +558,9 @@ pub fn validate_input_is_text_only(request: &NvCreateResponse) -> Option<impl In
/// Checks for unsupported fields in the request. /// Checks for unsupported fields in the request.
/// Returns Some(response) if unsupported fields are present. /// Returns Some(response) if unsupported fields are present.
pub fn validate_unsupported_fields(request: &NvCreateResponse) -> Option<impl IntoResponse> { pub fn validate_response_unsupported_fields(
request: &NvCreateResponse,
) -> Option<impl IntoResponse> {
let inner = &request.inner; let inner = &request.inner;
if inner.background == Some(true) { if inner.background == Some(true) {
...@@ -936,7 +983,7 @@ mod tests { ...@@ -936,7 +983,7 @@ mod tests {
#[test] #[test]
fn test_validate_input_is_text_only_accepts_text() { fn test_validate_input_is_text_only_accepts_text() {
let request = make_base_request(); let request = make_base_request();
let result = validate_input_is_text_only(&request); let result = validate_response_input_is_text_only(&request);
assert!(result.is_none()); assert!(result.is_none());
} }
...@@ -948,14 +995,14 @@ mod tests { ...@@ -948,14 +995,14 @@ mod tests {
role: ResponseRole::User, role: ResponseRole::User,
content: InputContent::TextInput("structured".into()), content: InputContent::TextInput("structured".into()),
})]); })]);
let result = validate_input_is_text_only(&request); let result = validate_response_input_is_text_only(&request);
assert!(result.is_some()); assert!(result.is_some());
} }
#[test] #[test]
fn test_validate_unsupported_fields_accepts_clean_request() { fn test_validate_unsupported_fields_accepts_clean_request() {
let request = make_base_request(); let request = make_base_request();
let result = validate_unsupported_fields(&request); let result = validate_response_unsupported_fields(&request);
assert!(result.is_none()); assert!(result.is_none());
} }
...@@ -1025,7 +1072,7 @@ mod tests { ...@@ -1025,7 +1072,7 @@ mod tests {
for (field, set_field) in unsupported_cases { for (field, set_field) in unsupported_cases {
let mut req = make_base_request(); let mut req = make_base_request();
(set_field)(&mut req.inner); (set_field)(&mut req.inner);
let result = validate_unsupported_fields(&req); let result = validate_response_unsupported_fields(&req);
assert!(result.is_some(), "Expected rejection for `{field}`"); assert!(result.is_some(), "Expected rejection for `{field}`");
} }
} }
......
...@@ -13,14 +13,16 @@ ...@@ -13,14 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::HashMap;
use serde_json::Value;
use uuid::Uuid;
mod request; mod request;
mod response; mod response;
pub use request::*; pub use request::*;
pub use response::*; pub use response::*;
use serde_json::Value;
use std::collections::HashMap;
use uuid::Uuid;
/// Matches and processes tool calling patterns in LLM responses /// Matches and processes tool calling patterns in LLM responses
/// ///
...@@ -113,3 +115,257 @@ impl ToolCallingMatcher { ...@@ -113,3 +115,257 @@ impl ToolCallingMatcher {
} }
} }
} }
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
///
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
/// including wrapped payloads (`<TOOLCALL>[...]</TOOLCALL>`, `<|python_tag|>...`) and JSON representations
/// with either `parameters` or `arguments` fields.
///
/// # Supported Formats
///
/// The input `message` may be one of:
///
/// - `<TOOLCALL>[{ "name": ..., "parameters": { ... } }]</TOOLCALL>`
/// - `<|python_tag|>{ "name": ..., "arguments": { ... } }`
/// - Raw JSON of:
/// - `CalledFunctionParameters`: `{ "name": ..., "parameters": { ... } }`
/// - `CalledFunctionArguments`: `{ "name": ..., "arguments": { ... } }`
/// - Or a list of either of those types: `[ { "name": ..., "arguments": { ... } }, ... ]`
///
/// # Return
///
/// - `Ok(Some(ToolCallResponse))` if parsing succeeds
/// - `Ok(None)` if input format is unrecognized or invalid JSON
/// - `Err(...)` if JSON is valid but deserialization or argument re-serialization fails
///
/// # Note on List Handling
///
/// When the input contains a list of tool calls (either with `parameters` or `arguments`),
/// only the **last item** in the list is returned. This design choice assumes that the
/// most recent tool call in a list is the one to execute.
///
/// # Errors
///
/// Returns a `Result::Err` only if an inner `serde_json::to_string(...)` fails
/// (e.g., if the arguments are not serializable).
///
/// # Examples
///
/// ```ignore
/// let input = r#"<TOOLCALL>[{ "name": "search", "parameters": { "query": "rust" } }]</TOOLCALL>"#;
/// let result = try_parse_call_common(input)?;
/// assert!(result.is_some());
/// ```
pub fn try_parse_call_common(message: &str) -> anyhow::Result<Option<ToolCallResponse>> {
let trimmed = message.trim();
// Support <TOOLCALL>[ ... ] or <tool_call>[ ... ]
let json = if trimmed.starts_with("<TOOLCALL>[") && trimmed.ends_with("]</TOOLCALL>") {
tracing::debug!("Stripping <TOOLCALL> wrapper from tool call payload");
&trimmed["<TOOLCALL>[".len()..trimmed.len() - "]</TOOLCALL>".len()]
// Support custom/LLM-formatted `<|python_tag|>` preamble
} else if let Some(stripped) = trimmed.strip_prefix("<|python_tag|>") {
tracing::debug!("Stripping <|python_tag|> prefix from tool call payload");
stripped
// Otherwise, assume input is clean JSON
} else {
trimmed
};
// Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
Ok(ToolCallResponse {
id: format!("call-{}", Uuid::new_v4()),
tp: ToolCallType::Function,
function: CalledFunction {
name,
arguments: serde_json::to_string(&args)?,
},
})
};
// CalledFunctionParameters: Single { name, parameters }
// Example:
// {
// "name": "search_docs",
// "parameters": {
// "query": "how to use Rust",
// "limit": 5
// }
// }
if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) {
return parse(single.name, single.parameters).map(Some);
// CalledFunctinoArguments: Single { name, arguments }
// Example:
// {
// "name": "summarize",
// "arguments": {
// "text": "Rust is a systems programming language.",
// "length": "short"
// }
// }
} else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) {
return parse(single.name, single.arguments).map(Some);
// Vec<CalledFunctionParameters>: List of { name, parameters }
// Example:
// [
// { "name": "lookup_user", "parameters": { "user_id": "123" } },
// { "name": "send_email", "parameters": { "to": "user@example.com", "subject": "Welcome!" } }
// ]
// We pop the last item in the list to use.
} else if let Ok(mut list) = serde_json::from_str::<Vec<CalledFunctionParameters>>(json) {
if let Some(item) = list.pop() {
return parse(item.name, item.parameters).map(Some);
}
// Vec<CalledFunctionArguments>: List of { name, arguments }
// Example:
// [
// {
// "name": "get_weather",
// "arguments": {
// "location": "San Francisco",
// "units": "celsius"
// }
// }
// ]
// Again, we take the last item for processing.
} else if let Ok(mut list) = serde_json::from_str::<Vec<CalledFunctionArguments>>(json) {
if let Some(item) = list.pop() {
return parse(item.name, item.arguments).map(Some);
}
}
Ok(None)
}
/// Try parsing a string as a structured tool call, for aggregation usage.
///
/// If successful, returns a `ChatCompletionMessageToolCall`.
pub fn try_parse_tool_call_aggregate(
message: &str,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCall>> {
let parsed = try_parse_call_common(message)?;
if let Some(parsed) = parsed {
Ok(Some(async_openai::types::ChatCompletionMessageToolCall {
id: parsed.id,
r#type: async_openai::types::ChatCompletionToolType::Function,
function: async_openai::types::FunctionCall {
name: parsed.function.name,
arguments: parsed.function.arguments,
},
}))
} else {
Ok(None)
}
}
/// Try parsing a string as a structured tool call, for streaming (delta) usage.
///
/// If successful, returns a `ChatCompletionMessageToolCallChunk`.
pub fn try_parse_tool_call_stream(
message: &str,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCallChunk>> {
let parsed = try_parse_call_common(message)?;
if let Some(parsed) = parsed {
Ok(Some(
async_openai::types::ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(parsed.id),
r#type: Some(async_openai::types::ChatCompletionToolType::Function),
function: Some(async_openai::types::FunctionCallStream {
name: Some(parsed.function.name),
arguments: Some(parsed.function.arguments),
}),
},
))
} else {
Ok(None)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
(call.function.name, args)
}
#[test]
fn parses_single_parameters_object() {
let input = r#"{ "name": "hello", "parameters": { "x": 1, "y": 2 } }"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "hello");
assert_eq!(args["x"], 1);
assert_eq!(args["y"], 2);
}
#[test]
fn parses_single_arguments_object() {
let input = r#"{ "name": "world", "arguments": { "a": "abc", "b": 42 } }"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "world");
assert_eq!(args["a"], "abc");
assert_eq!(args["b"], 42);
}
#[test]
fn parses_vec_of_parameters_and_takes_last() {
let input = r#"[{ "name": "first", "parameters": { "a": 1 } }, { "name": "second", "parameters": { "b": 2 } }]"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "second");
assert_eq!(args["b"], 2);
}
#[test]
fn parses_vec_of_arguments_and_takes_last() {
let input = r#"[{ "name": "alpha", "arguments": { "a": "x" } }, { "name": "omega", "arguments": { "z": "y" } }]"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "omega");
assert_eq!(args["z"], "y");
}
#[test]
fn parses_toolcall_wrapped_payload() {
let input =
r#"<TOOLCALL>[{ "name": "wrapped", "parameters": { "foo": "bar" } }]</TOOLCALL>"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "wrapped");
assert_eq!(args["foo"], "bar");
}
#[test]
fn parses_python_tag_prefixed_payload() {
let input = r#"<|python_tag|>{ "name": "pyfunc", "arguments": { "k": "v" } }"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "pyfunc");
assert_eq!(args["k"], "v");
}
#[test]
fn returns_none_on_invalid_input() {
let input = r#"not even json"#;
let result = try_parse_call_common(input).unwrap();
assert!(result.is_none());
}
#[test]
fn returns_none_on_valid_json_wrong_shape() {
let input = r#"{ "foo": "bar" }"#;
let result = try_parse_call_common(input).unwrap();
assert!(result.is_none());
}
}
...@@ -60,6 +60,8 @@ struct DeltaChoice { ...@@ -60,6 +60,8 @@ struct DeltaChoice {
finish_reason: Option<async_openai::types::FinishReason>, finish_reason: Option<async_openai::types::FinishReason>,
/// Optional log probabilities for the chat choice. /// Optional log probabilities for the chat choice.
logprobs: Option<async_openai::types::ChatChoiceLogprobs>, logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
// Optional tool calls for the chat choice.
tool_calls: Option<Vec<async_openai::types::ChatCompletionMessageToolCall>>,
} }
impl Default for DeltaAggregator { impl Default for DeltaAggregator {
...@@ -135,6 +137,7 @@ impl DeltaAggregator { ...@@ -135,6 +137,7 @@ impl DeltaAggregator {
role: choice.delta.role, role: choice.delta.role,
finish_reason: None, finish_reason: None,
logprobs: choice.logprobs, logprobs: choice.logprobs,
tool_calls: None,
}); });
// Append content if available. // Append content if available.
...@@ -153,12 +156,32 @@ impl DeltaAggregator { ...@@ -153,12 +156,32 @@ impl DeltaAggregator {
.await; .await;
// Return early if an error was encountered. // Return early if an error was encountered.
let aggregator = if let Some(error) = aggregator.error { let mut aggregator = if let Some(error) = aggregator.error {
return Err(error); return Err(error);
} else { } else {
aggregator aggregator
}; };
// After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() {
if let Ok(Some(tool_call)) =
crate::preprocessor::tools::try_parse_tool_call_aggregate(&choice.text)
{
tracing::debug!(
tool_call_id = %tool_call.id,
function_name = %tool_call.function.name,
arguments = %tool_call.function.arguments,
"Parsed structured tool call from aggregated content"
);
choice.tool_calls = Some(vec![tool_call]);
choice.text.clear();
choice.finish_reason = Some(async_openai::types::FinishReason::ToolCalls);
}
}
}
// Extract aggregated choices and sort them by index. // Extract aggregated choices and sort them by index.
let mut choices: Vec<_> = aggregator let mut choices: Vec<_> = aggregator
.choices .choices
...@@ -196,8 +219,12 @@ impl From<DeltaChoice> for async_openai::types::ChatChoice { ...@@ -196,8 +219,12 @@ impl From<DeltaChoice> for async_openai::types::ChatChoice {
async_openai::types::ChatChoice { async_openai::types::ChatChoice {
message: async_openai::types::ChatCompletionResponseMessage { message: async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"), role: delta.role.expect("delta should have a Role"),
content: Some(delta.text), content: if delta.tool_calls.is_some() {
tool_calls: None, None
} else {
Some(delta.text)
},
tool_calls: delta.tool_calls,
refusal: None, refusal: None,
function_call: None, function_call: None,
audio: None, audio: None,
......
...@@ -130,16 +130,15 @@ impl DeltaGenerator { ...@@ -130,16 +130,15 @@ impl DeltaGenerator {
finish_reason: Option<async_openai::types::FinishReason>, finish_reason: Option<async_openai::types::FinishReason>,
logprobs: Option<async_openai::types::ChatChoiceLogprobs>, logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
) -> async_openai::types::CreateChatCompletionStreamResponse { ) -> async_openai::types::CreateChatCompletionStreamResponse {
// TODO: Update for tool calling
let delta = async_openai::types::ChatCompletionStreamResponseDelta { let delta = async_openai::types::ChatCompletionStreamResponseDelta {
content: text,
function_call: None,
tool_calls: None,
role: if self.msg_counter == 0 { role: if self.msg_counter == 0 {
Some(async_openai::types::Role::Assistant) Some(async_openai::types::Role::Assistant)
} else { } else {
None None
}, },
content: text,
tool_calls: None,
function_call: None,
refusal: None, refusal: None,
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment