Unverified commit d1676cd4 authored by Chang Su, committed by GitHub

[router][tool call] Full support for ToolChoice (#11085)


Co-authored-by: Simo Lin <linsimo.mark@gmail.com>
parent 33b3c0f8
......@@ -1491,6 +1491,7 @@ impl ResponsesResponse {
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
ToolChoice::Function { .. } => "function".to_string(),
ToolChoice::AllowedTools { mode, .. } => mode.clone(),
},
tools: request.tools.clone(),
top_p: request.top_p,
......@@ -1718,6 +1719,12 @@ pub enum ToolChoice {
tool_type: String, // "function"
function: FunctionChoice,
},
AllowedTools {
#[serde(rename = "type")]
tool_type: String, // "allowed_tools"
mode: String, // "auto" | "required" TODO: need validation
tools: Vec<ToolReference>,
},
}
impl Default for ToolChoice {
......@@ -1732,6 +1739,14 @@ pub struct FunctionChoice {
pub name: String,
}
/// Tool reference for ToolChoice::AllowedTools
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolReference {
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub name: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
#[serde(rename = "type")]
......
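For reference, a minimal sketch of how the new variant is populated. It assumes the `ToolChoice` and `ToolReference` definitions added in the hunk above are in scope; the tool names and the wrapping function are hypothetical, and the mirrored JSON wire shape (`{"type": "allowed_tools", ...}`) is inferred from the field names and `serde(rename)` attributes in the diff, not confirmed by it.

```rust
// Illustrative only: constructing the new variant by hand to show the shape it
// models. Assumes the ToolChoice / ToolReference types from the diff above.
fn example_allowed_tools_choice() -> ToolChoice {
    // Mirrors a client payload roughly like:
    //   {"type": "allowed_tools", "mode": "required",
    //    "tools": [{"type": "function", "name": "get_weather"},
    //              {"type": "function", "name": "get_time"}]}
    ToolChoice::AllowedTools {
        tool_type: "allowed_tools".to_string(),
        mode: "required".to_string(), // "auto" | "required"
        tools: vec![
            ToolReference {
                tool_type: "function".to_string(),
                name: "get_weather".to_string(), // hypothetical tool name
            },
            ToolReference {
                tool_type: "function".to_string(),
                name: "get_time".to_string(), // hypothetical tool name
            },
        ],
    }
}
```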
// gRPC Router Implementation
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
......@@ -20,8 +21,9 @@ use crate::policies::PolicyRegistry;
use crate::protocols::spec::ChatMessage;
use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice,
ToolChoiceValue, Usage,
};
use crate::reasoning_parser::ParserFactory;
use crate::routers::RouterTrait;
......@@ -34,7 +36,7 @@ use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
use crate::tool_parser::ParserRegistry;
use proto::generate_response::Response::{Chunk, Complete, Error};
use serde_json::{json, Value};
use serde_json::{json, Map, Value};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio_stream::StreamExt;
use uuid::Uuid;
......@@ -132,8 +134,39 @@ impl GrpcRouter {
Err(response) => return response,
};
// Step 3: Process messages and apply chat template
let processed_messages = match self.process_chat_messages(body) {
// Step 3: Filter tools if needed for allowed_tools or specific function
// Only clone body if we need to modify tools
let mut body_with_filtered_tools;
let body_ref = match &body.tool_choice {
Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => {
body_with_filtered_tools = body.clone();
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
let allowed_names: std::collections::HashSet<&str> =
allowed.iter().map(|t| t.name.as_str()).collect();
let filtered_tools: Vec<Tool> = all_tools
.iter()
.filter(|t| allowed_names.contains(t.function.name.as_str()))
.cloned()
.collect();
body_with_filtered_tools.tools = Some(filtered_tools);
&body_with_filtered_tools
}
Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => {
body_with_filtered_tools = body.clone();
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
let filtered_tools: Vec<Tool> = all_tools
.iter()
.filter(|t| t.function.name == function.name)
.cloned()
.collect();
body_with_filtered_tools.tools = Some(filtered_tools);
&body_with_filtered_tools
}
_ => body, // No filtering needed, use original
};
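The match above narrows `body.tools` to the subset named by `tool_choice` before chat templating and constraint generation, cloning the body only when filtering is actually needed. A self-contained sketch of the same set-intersection idea, with plain strings standing in for the router's `Tool`/`ToolReference` types (tool names hypothetical):

```rust
use std::collections::HashSet;

// Keep only the tools whose names appear in the allowed list, preserving the
// original order of `all_tools` (same approach as the AllowedTools arm above).
fn filter_allowed(all_tools: &[&str], allowed: &[&str]) -> Vec<String> {
    let allowed_names: HashSet<&str> = allowed.iter().copied().collect();
    all_tools
        .iter()
        .filter(|name| allowed_names.contains(**name))
        .map(|name| name.to_string())
        .collect()
}

fn main() {
    let tools = ["get_weather", "get_time", "send_email"];
    let allowed = ["get_time", "get_weather"];
    assert_eq!(filter_allowed(&tools, &allowed), ["get_weather", "get_time"]);
}
```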
// Step 4: Process messages and apply chat template
let processed_messages = match self.process_chat_messages(body_ref) {
Ok(msgs) => msgs,
Err(e) => {
error!("Failed to process chat messages: {}", e);
......@@ -141,7 +174,7 @@ impl GrpcRouter {
}
};
// Step 4: Tokenize the processed text
// Step 5: Tokenize the processed text
let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
......@@ -157,18 +190,17 @@ impl GrpcRouter {
let token_ids = encoding.token_ids().to_vec();
debug!("Tokenized {} tokens from input", token_ids.len());
// Step 5: Build tool constraints if needed
let tool_call_constraint = if let Some(tools) = &body.tools {
// Step 6: Build tool constraints if needed
// body_ref already has filtered tools if needed
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
} else {
None
};
});
// Step 6: Build the base gRPC request
// Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable)
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let request = match client.build_generate_request(
request_id,
body,
body_ref,
processed_messages.text.clone(),
token_ids,
processed_messages.multimodal_inputs,
......@@ -561,16 +593,227 @@ impl GrpcRouter {
}
/// Generate tool constraints for structured generation
/// Note: tools should already be filtered if needed (by allowed_tools or specific function)
fn generate_tool_constraints(
&self,
_tools: &[Tool],
_tool_choice: &Option<ToolChoice>,
model: &str,
tools: &[Tool],
tool_choice: &Option<ToolChoice>,
_model: &str,
) -> Option<(String, String)> {
let _parser = self.tool_parser_registry.get_parser(model)?;
// TODO: Implement actual constraint generation logic
// For now, return None as this is placeholder implementation
None
let choice = tool_choice.as_ref()?;
match choice {
// Specific function: Return parameters schema directly
// tools should already be filtered to contain only the specific function
ToolChoice::Function { .. } => {
if tools.is_empty() {
return None;
}
let tool = &tools[0];
// Return the tool's parameters schema directly (not wrapped in array)
let params_schema = serde_json::to_string(&tool.function.parameters).ok()?;
Some(("json_schema".to_string(), params_schema))
}
// Required: Array of tool calls with minItems: 1
ToolChoice::Value(ToolChoiceValue::Required) => {
let schema = self.build_required_array_schema(tools)?;
Some(("json_schema".to_string(), schema))
}
// AllowedTools with required mode: tools are already filtered
ToolChoice::AllowedTools { mode, .. } => {
if mode == "required" {
if tools.is_empty() {
return None;
}
let schema = self.build_required_array_schema(tools)?;
Some(("json_schema".to_string(), schema))
} else {
// "auto" mode - no constraint needed
None
}
}
// "auto" or "none" - no constraint
_ => None,
}
}
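A concrete illustration of the specific-function branch: since the tool list has already been filtered down to the one chosen function, the constraint returned is simply that tool's `parameters` schema serialized to a string, keyed as `"json_schema"`. The `city` parameter below is a hypothetical example, not taken from the diff.

```rust
use serde_json::json;

fn main() {
    // Hypothetical parameters schema of the single filtered tool.
    let parameters = json!({
        "type": "object",
        "properties": { "city": { "type": "string" } },
        "required": ["city"]
    });

    // For ToolChoice::Function the constraint is the schema itself
    // (not wrapped in an array), paired with the "json_schema" key.
    let constraint = (
        "json_schema".to_string(),
        serde_json::to_string(&parameters).unwrap(),
    );
    println!("{} => {}", constraint.0, constraint.1);
}
```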
/// Build JSON schema for required tool calls (array with minItems: 1)
/// Includes $defs consolidation from all tools (matching Python's behavior)
fn build_required_array_schema(&self, tools: &[Tool]) -> Option<String> {
// Build anyOf schemas for each tool
let mut any_of_schemas = Vec::new();
for tool in tools {
let tool_schema = json!({
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
"parameters": tool.function.parameters
},
"required": ["name", "parameters"]
});
any_of_schemas.push(tool_schema);
}
// Consolidate $defs from all tools (matching Python's _get_tool_schema_defs)
let mut all_defs: HashMap<String, Value> = HashMap::new();
for tool in tools {
if let Value::Object(params) = &tool.function.parameters {
if let Some(Value::Object(defs)) = params.get("$defs") {
for (def_name, def_schema) in defs {
if let Some(existing) = all_defs.get(def_name) {
// Check for conflicts
if existing != def_schema {
error!(
"Tool definition '{}' has multiple schemas, which is not supported",
def_name
);
return None;
}
} else {
all_defs.insert(def_name.clone(), def_schema.clone());
}
}
}
}
}
// Build the full array schema
let mut array_schema = json!({
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": any_of_schemas
}
});
// Add $defs if any were found (matching Python's behavior)
if !all_defs.is_empty() {
if let Value::Object(ref mut schema_obj) = array_schema {
let defs_value =
Value::Object(all_defs.into_iter().collect::<Map<String, Value>>());
schema_obj.insert("$defs".to_string(), defs_value);
}
}
serde_json::to_string(&array_schema).ok()
}
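For two hypothetical tools, `get_weather(city)` and `get_time(tz)`, the array schema built above would come out roughly as follows. This is a sketch of the expected output shape for illustration, not router code:

```rust
use serde_json::json;

fn main() {
    // Approximate result of build_required_array_schema for two example tools:
    // an array with at least one element, where each element must match one of
    // the per-tool {name, parameters} object schemas.
    let schema = json!({
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": [
                {
                    "properties": {
                        "name": { "type": "string", "enum": ["get_weather"] },
                        "parameters": {
                            "type": "object",
                            "properties": { "city": { "type": "string" } },
                            "required": ["city"]
                        }
                    },
                    "required": ["name", "parameters"]
                },
                {
                    "properties": {
                        "name": { "type": "string", "enum": ["get_time"] },
                        "parameters": {
                            "type": "object",
                            "properties": { "tz": { "type": "string" } },
                            "required": ["tz"]
                        }
                    },
                    "required": ["name", "parameters"]
                }
            ]
        }
    });
    println!("{}", serde_json::to_string_pretty(&schema).unwrap());
}
```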
/// Parse tool calls from JSON schema constrained response
fn parse_json_schema_response(
&self,
processed_text: &str,
tool_choice: &Option<ToolChoice>,
) -> (Option<Vec<ToolCall>>, String) {
match tool_choice {
Some(ToolChoice::Function { function, .. }) => {
// Specific function: Parse parameters directly
match serde_json::from_str::<Value>(processed_text) {
Ok(params) => {
let tool_call = ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: function.name.clone(),
arguments: Some(
serde_json::to_string(&params)
.unwrap_or_else(|_| "{}".to_string()),
),
},
};
(Some(vec![tool_call]), String::new())
}
Err(e) => {
error!("Failed to parse specific function parameters: {}", e);
(None, processed_text.to_string())
}
}
}
Some(ToolChoice::Value(ToolChoiceValue::Required))
| Some(ToolChoice::AllowedTools { .. }) => {
// Required mode: Parse array of tool calls
match serde_json::from_str::<Vec<Value>>(processed_text) {
Ok(parsed_array) => {
let spec_tool_calls: Vec<ToolCall> = parsed_array
.into_iter()
.enumerate()
.filter_map(|(i, item)| {
let obj = item.as_object()?;
let name = obj.get("name")?.as_str()?.to_string();
let parameters = obj.get("parameters")?;
Some(ToolCall {
id: format!("call_{}_{}", i, uuid::Uuid::new_v4()),
tool_type: "function".to_string(),
function: FunctionCallResponse {
name,
arguments: Some(
serde_json::to_string(parameters)
.unwrap_or_else(|_| "{}".to_string()),
),
},
})
})
.collect();
(Some(spec_tool_calls), String::new())
}
Err(e) => {
error!("Failed to parse required tool call array: {}", e);
(None, processed_text.to_string())
}
}
}
_ => (None, processed_text.to_string()),
}
}
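A self-contained sketch of the required-mode branch: the constrained generation is a JSON array of `{name, parameters}` objects, and each element becomes one tool call. The model output string and tool name below are hypothetical, and the extraction is reduced to plain `(name, arguments)` pairs instead of the router's `ToolCall` type.

```rust
use serde_json::Value;

fn main() {
    // Hypothetical model output produced under the required-mode array constraint.
    let generated = r#"[{"name": "get_weather", "parameters": {"city": "Paris"}}]"#;

    // Same extraction idea as parse_json_schema_response, simplified.
    let calls: Vec<(String, String)> = serde_json::from_str::<Vec<Value>>(generated)
        .unwrap_or_default()
        .into_iter()
        .filter_map(|item| {
            let obj = item.as_object()?;
            let name = obj.get("name")?.as_str()?.to_string();
            let args = serde_json::to_string(obj.get("parameters")?).ok()?;
            Some((name, args))
        })
        .collect();

    assert_eq!(calls[0].0, "get_weather");
    assert_eq!(calls[0].1, r#"{"city":"Paris"}"#);
}
```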
/// Parse tool calls using model-specific parser
async fn parse_with_model_parser(
&self,
processed_text: &str,
model: &str,
) -> (Option<Vec<ToolCall>>, String) {
let Some(parser) = self.tool_parser_registry.get_parser(model) else {
return (None, processed_text.to_string());
};
if !parser.detect_format(processed_text) {
return (None, processed_text.to_string());
}
match parser.parse_complete(processed_text).await {
Ok((normal_text, parsed_tool_calls)) => {
if parsed_tool_calls.is_empty() {
return (None, normal_text);
}
let spec_tool_calls = parsed_tool_calls
.into_iter()
.map(|tc| ToolCall {
id: tc.id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
})
.collect();
(Some(spec_tool_calls), normal_text)
}
Err(e) => {
error!("Tool call parsing error: {}", e);
(None, processed_text.to_string())
}
}
}
/// Resolve the generate input into optional original text and token IDs
......@@ -1130,36 +1373,21 @@ impl GrpcRouter {
);
if tool_choice_enabled && original_request.tools.is_some() {
if let Some(parser) = self
.tool_parser_registry
.get_parser(&original_request.model)
{
match parser.parse_complete(&processed_text).await {
Ok((normal_text, parsed_tool_calls)) => {
if !parsed_tool_calls.is_empty() {
let spec_tool_calls = parsed_tool_calls
.into_iter()
.map(|tc| crate::protocols::spec::ToolCall {
id: tc.id,
tool_type: "function".to_string(),
function: crate::protocols::spec::FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
})
.collect();
tool_calls = Some(spec_tool_calls);
processed_text = normal_text;
}
}
Err(e) => {
error!("Tool call parsing error: {}", e);
// Continue without tool calls rather than failing
}
}
// Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match &original_request.tool_choice {
Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(crate::protocols::spec::ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false,
};
if used_json_schema {
(tool_calls, processed_text) =
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
} else {
(tool_calls, processed_text) = self
.parse_with_model_parser(&processed_text, &original_request.model)
.await;
}
}
......