Unverified commit 9f5e7018, authored by Chang Su and committed by GitHub
Browse files

[router][grpc] Implement tool_choice support for Responses API (#12668)

parent cbf23dbb
...@@ -303,6 +303,9 @@ impl SglangSchedulerClient { ...@@ -303,6 +303,9 @@ impl SglangSchedulerClient {
} }
/// Build a GenerateRequest from ResponsesRequest (OpenAI Responses API) /// Build a GenerateRequest from ResponsesRequest (OpenAI Responses API)
///
/// NOTE: This is used by the Harmony router only. The Regular router uses
/// responses_to_chat() conversion and goes through the chat pipeline.
pub fn build_generate_request_from_responses( pub fn build_generate_request_from_responses(
&self, &self,
request_id: String, request_id: String,
...@@ -310,9 +313,11 @@ impl SglangSchedulerClient { ...@@ -310,9 +313,11 @@ impl SglangSchedulerClient {
processed_text: String, processed_text: String,
token_ids: Vec<u32>, token_ids: Vec<u32>,
harmony_stop_ids: Option<Vec<u32>>, harmony_stop_ids: Option<Vec<u32>>,
tool_call_constraint: Option<(String, String)>,
) -> Result<proto::GenerateRequest, String> { ) -> Result<proto::GenerateRequest, String> {
// Build sampling params from ResponsesRequest // Build sampling params from ResponsesRequest
let mut sampling_params = self.build_grpc_sampling_params_from_responses(body)?; let mut sampling_params =
self.build_grpc_sampling_params_from_responses(body, tool_call_constraint)?;
// Inject Harmony stop token IDs if provided // Inject Harmony stop token IDs if provided
if let Some(stop_ids) = harmony_stop_ids { if let Some(stop_ids) = harmony_stop_ids {
...@@ -441,9 +446,10 @@ impl SglangSchedulerClient { ...@@ -441,9 +446,10 @@ impl SglangSchedulerClient {
fn build_grpc_sampling_params_from_responses( fn build_grpc_sampling_params_from_responses(
&self, &self,
request: &ResponsesRequest, request: &ResponsesRequest,
tool_call_constraint: Option<(String, String)>,
) -> Result<proto::SamplingParams, String> { ) -> Result<proto::SamplingParams, String> {
// ResponsesRequest doesn't have stop sequences in the same way // ResponsesRequest doesn't have stop sequences in the same way
// Tools are handled externally by MCP loop, not via constraints // For Harmony router: Tools are handled via structural_tag constraints
let max_new_tokens = request.max_output_tokens.map(|v| v as i32); let max_new_tokens = request.max_output_tokens.map(|v| v as i32);
...@@ -462,12 +468,36 @@ impl SglangSchedulerClient { ...@@ -462,12 +468,36 @@ impl SglangSchedulerClient {
spaces_between_special_tokens: true, spaces_between_special_tokens: true,
ignore_eos: false, ignore_eos: false,
no_stop_trim: false, no_stop_trim: false,
n: 1, // Responses API doesn't support n>1 n: 1, // Responses API doesn't support n>1
constraint: None, // No constraints - tools handled by MCP constraint: self.build_constraint_for_responses(tool_call_constraint)?,
..Default::default() ..Default::default()
}) })
} }
/// Translate an optional tool-call constraint into a gRPC sampling constraint.
///
/// This is the Responses API counterpart of the Chat API's `build_constraint`
/// and is simpler: the only input is `tool_call_constraint`, a
/// `(kind, value)` pair. The kind selects the
/// `proto::sampling_params::Constraint` variant and the value becomes its
/// payload unchanged:
///
/// - `"structural_tag"` -> `StructuralTag`
/// - `"json_schema"` -> `JsonSchema`
/// - `"ebnf"` -> `EbnfGrammar`
/// - `"regex"` -> `Regex`
///
/// Returns `Ok(None)` when no constraint is supplied.
///
/// # Errors
///
/// Returns `Err` with an "Unknown constraint type" message for any other kind.
fn build_constraint_for_responses(
    &self,
    tool_call_constraint: Option<(String, String)>,
) -> Result<Option<proto::sampling_params::Constraint>, String> {
    tool_call_constraint
        .map(|(kind, payload)| match kind.as_str() {
            "structural_tag" => Ok(proto::sampling_params::Constraint::StructuralTag(payload)),
            "json_schema" => Ok(proto::sampling_params::Constraint::JsonSchema(payload)),
            "ebnf" => Ok(proto::sampling_params::Constraint::EbnfGrammar(payload)),
            "regex" => Ok(proto::sampling_params::Constraint::Regex(payload)),
            _ => Err(format!("Unknown constraint type: {}", kind)),
        })
        // Option<Result<_, _>> -> Result<Option<_>, _>: absent input stays Ok(None).
        .transpose()
}
fn build_single_constraint_from_plain( fn build_single_constraint_from_plain(
params: &GenerateSamplingParams, params: &GenerateSamplingParams,
) -> Result<Option<proto::sampling_params::Constraint>, String> { ) -> Result<Option<proto::sampling_params::Constraint>, String> {
......
...@@ -457,21 +457,43 @@ fn validate_chat_cross_parameters( ...@@ -457,21 +457,43 @@ fn validate_chat_cross_parameters(
return Err(e); return Err(e);
} }
// Validate that all referenced tool names exist in tools // Validate that all ToolReferences are Function type (Chat API only supports function tools)
for tool_ref in allowed_tools { for tool_ref in allowed_tools {
let tool_exists = tools.iter().any(|tool| { match tool_ref {
tool.tool_type == tool_ref.tool_type ToolReference::Function { name } => {
&& tool.function.name == tool_ref.name // Validate that the function exists in tools array
}); let tool_exists = tools.iter().any(|tool| {
tool.tool_type == "function" && tool.function.name == *name
if !tool_exists { });
let mut e =
validator::ValidationError::new("tool_choice_tool_not_found"); if !tool_exists {
e.message = Some(format!( let mut e = validator::ValidationError::new(
"Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.", "tool_choice_tool_not_found",
tool_ref.name );
).into()); e.message = Some(
return Err(e); format!(
"Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.",
name
)
.into(),
);
return Err(e);
}
}
_ => {
// Chat Completion API only supports function tools in tool_choice
let mut e = validator::ValidationError::new(
"tool_choice_invalid_tool_type",
);
e.message = Some(
format!(
"Invalid value for 'tool_choice.tools': Chat Completion API only supports function tools, got '{}'.",
tool_ref.identifier()
)
.into(),
);
return Err(e);
}
} }
} }
} }
......
...@@ -183,6 +183,18 @@ impl Default for ToolChoice { ...@@ -183,6 +183,18 @@ impl Default for ToolChoice {
} }
} }
impl ToolChoice {
    /// Render an optional `tool_choice` as the string stored on a ResponsesResponse.
    ///
    /// - `None` yields the literal `auto`.
    /// - `Some(choice)` yields the output of `serde_json::to_string(choice)`;
    ///   if serialization fails, the result falls back to `auto`.
    ///
    /// NOTE(review): the `Some` path returns JSON text, so a choice that
    /// serializes to a JSON string would come back wrapped in double quotes,
    /// unlike the unquoted `auto` fallback — confirm consumers expect that.
    pub fn serialize_to_string(tool_choice: &Option<ToolChoice>) -> String {
        match tool_choice {
            Some(choice) => {
                serde_json::to_string(choice).unwrap_or_else(|_| "auto".to_string())
            }
            None => "auto".to_string(),
        }
    }
}
/// Function choice specification for ToolChoice::Function /// Function choice specification for ToolChoice::Function
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice { pub struct FunctionChoice {
...@@ -190,11 +202,73 @@ pub struct FunctionChoice { ...@@ -190,11 +202,73 @@ pub struct FunctionChoice {
} }
/// Tool reference for ToolChoice::AllowedTools /// Tool reference for ToolChoice::AllowedTools
///
/// Represents a reference to a specific tool in the allowed_tools array.
/// Different tool types have different required fields.
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolReference { #[serde(tag = "type")]
#[serde(rename = "type")] #[serde(rename_all = "snake_case")]
pub tool_type: String, // "function" pub enum ToolReference {
pub name: String, /// Reference to a function tool
#[serde(rename = "function")]
Function { name: String },
/// Reference to an MCP tool
#[serde(rename = "mcp")]
Mcp {
server_label: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// File search hosted tool
#[serde(rename = "file_search")]
FileSearch,
/// Web search preview hosted tool
#[serde(rename = "web_search_preview")]
WebSearchPreview,
/// Computer use preview hosted tool
#[serde(rename = "computer_use_preview")]
ComputerUsePreview,
/// Code interpreter hosted tool
#[serde(rename = "code_interpreter")]
CodeInterpreter,
/// Image generation hosted tool
#[serde(rename = "image_generation")]
ImageGeneration,
}
impl ToolReference {
/// Build an identifier string for this tool reference.
///
/// Formats produced:
///
/// - function tools: `function:<name>`
/// - MCP tools: `mcp:<server_label>:<name>`, or `mcp:<server_label>` when
///   no tool name is set
/// - hosted tools: the bare type name, e.g. `file_search`
pub fn identifier(&self) -> String {
    match self {
        ToolReference::Function { name } => format!("function:{}", name),
        // MCP references may or may not pin a specific tool on the server.
        ToolReference::Mcp {
            server_label,
            name: Some(tool),
        } => format!("mcp:{}:{}", server_label, tool),
        ToolReference::Mcp {
            server_label,
            name: None,
        } => format!("mcp:{}", server_label),
        ToolReference::FileSearch => String::from("file_search"),
        ToolReference::WebSearchPreview => String::from("web_search_preview"),
        ToolReference::ComputerUsePreview => String::from("computer_use_preview"),
        ToolReference::CodeInterpreter => String::from("code_interpreter"),
        ToolReference::ImageGeneration => String::from("image_generation"),
    }
}
/// Return the referenced function's name, or `None` when this reference is
/// not a function tool.
pub fn function_name(&self) -> Option<&str> {
    if let ToolReference::Function { name } = self {
        Some(name.as_str())
    } else {
        None
    }
}
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
......
...@@ -447,6 +447,7 @@ fn default_top_p() -> Option<f32> { ...@@ -447,6 +447,7 @@ fn default_top_p() -> Option<f32> {
// ============================================================================ // ============================================================================
#[derive(Debug, Clone, Deserialize, Serialize, Validate)] #[derive(Debug, Clone, Deserialize, Serialize, Validate)]
#[validate(schema(function = "validate_responses_cross_parameters"))]
pub struct ResponsesRequest { pub struct ResponsesRequest {
/// Run the request in the background /// Run the request in the background
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -721,6 +722,83 @@ pub fn validate_conversation_id(conv_id: &str) -> Result<(), validator::Validati ...@@ -721,6 +722,83 @@ pub fn validate_conversation_id(conv_id: &str) -> Result<(), validator::Validati
Ok(()) Ok(())
} }
/// Schema-level validation for cross-field dependencies
fn validate_responses_cross_parameters(
request: &ResponsesRequest,
) -> Result<(), validator::ValidationError> {
use super::common::{ToolChoice, ToolReference};
// Only validate if both tools and tool_choice are present
if let (Some(tools), Some(tool_choice)) = (&request.tools, &request.tool_choice) {
// Extract function tool names from ResponseTools
let function_tool_names: Vec<&str> = tools
.iter()
.filter_map(|t| match t.r#type {
ResponseToolType::Function => t.function.as_ref().map(|f| f.name.as_str()),
_ => None,
})
.collect();
match tool_choice {
ToolChoice::Function { function, .. } => {
// Validate the specific function exists
if !function_tool_names.contains(&function.name.as_str()) {
let mut e = validator::ValidationError::new("tool_choice_function_not_found");
e.message = Some(
format!(
"Invalid value for 'tool_choice': function '{}' not found in 'tools'.",
function.name
)
.into(),
);
return Err(e);
}
}
ToolChoice::AllowedTools {
mode,
tools: allowed_tools,
..
} => {
// Validate mode is "auto" or "required"
if mode != "auto" && mode != "required" {
let mut e = validator::ValidationError::new("tool_choice_invalid_mode");
e.message = Some(
format!(
"Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.",
mode
)
.into(),
);
return Err(e);
}
// Validate that all function tool references exist
for tool_ref in allowed_tools {
if let ToolReference::Function { name } = tool_ref {
if !function_tool_names.contains(&name.as_str()) {
let mut e =
validator::ValidationError::new("tool_choice_tool_not_found");
e.message = Some(
format!(
"Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.",
name
)
.into(),
);
return Err(e);
}
}
// Note: MCP and hosted tools don't need existence validation here
// as they are resolved dynamically at runtime
}
}
_ => {}
}
}
Ok(())
}
/// Normalize a SimpleInputMessage to a proper Message item /// Normalize a SimpleInputMessage to a proper Message item
/// ///
/// This helper converts SimpleInputMessage (which can have flexible content) /// This helper converts SimpleInputMessage (which can have flexible content)
......
...@@ -11,7 +11,10 @@ use serde_json::json; ...@@ -11,7 +11,10 @@ use serde_json::json;
use crate::{ use crate::{
core::WorkerRegistry, core::WorkerRegistry,
mcp::McpManager, mcp::McpManager,
protocols::responses::{ResponseTool, ResponseToolType}, protocols::{
common::Tool,
responses::{ResponseTool, ResponseToolType},
},
routers::{grpc::error, openai::mcp::ensure_request_mcp_client}, routers::{grpc::error, openai::mcp::ensure_request_mcp_client},
}; };
...@@ -76,3 +79,47 @@ pub fn validate_worker_availability( ...@@ -76,3 +79,47 @@ pub fn validate_worker_availability(
None None
} }
/// Collect the tools that carry a function schema out of a ResponseTool list.
///
/// Each selected entry's `function` is cloned into a chat-style `Tool` whose
/// `tool_type` is `"function"`. Selected entries without a `function` are
/// dropped.
///
/// Selection rules:
///
/// - `ResponseToolType::Function` entries are always considered.
/// - `ResponseToolType::Mcp` entries are considered only when `include_mcp`
///   is `true`. The Harmony preparation stage passes `true` so that MCP tools
///   take part in structural-constraint generation; the Responses-to-Chat
///   conversion passes `false` and leaves MCP tools to be merged in later by
///   the tool loop.
/// - Every other tool type (hosted tools) is skipped, as no schema is
///   available for it.
///
/// Returns an empty vector when `response_tools` is `None`.
pub fn extract_tools_from_response_tools(
    response_tools: Option<&[ResponseTool]>,
    include_mcp: bool,
) -> Vec<Tool> {
    let mut extracted = Vec::new();

    for candidate in response_tools.unwrap_or_default() {
        let selected = match candidate.r#type {
            ResponseToolType::Function => true,
            ResponseToolType::Mcp => include_mcp,
            _ => false,
        };
        if !selected {
            continue;
        }

        // Both Function and MCP entries are surfaced as plain function tools.
        if let Some(schema) = &candidate.function {
            extracted.push(Tool {
                tool_type: "function".to_string(),
                function: schema.clone(),
            });
        }
    }

    extracted
}
...@@ -52,7 +52,7 @@ use crate::{ ...@@ -52,7 +52,7 @@ use crate::{
data_connector::{ResponseId, ResponseStorage}, data_connector::{ResponseId, ResponseStorage},
mcp::{self, McpManager}, mcp::{self, McpManager},
protocols::{ protocols::{
common::{Function, ToolCall, Usage}, common::{Function, ToolCall, ToolChoice, ToolChoiceValue, Usage},
responses::{ responses::{
McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem, McpToolInfo, ResponseContentPart, ResponseInput, ResponseInputOutputItem,
ResponseOutputItem, ResponseReasoningContent, ResponseStatus, ResponseTool, ResponseOutputItem, ResponseReasoningContent, ResponseStatus, ResponseTool,
...@@ -467,15 +467,6 @@ async fn execute_without_mcp_loop( ...@@ -467,15 +467,6 @@ async fn execute_without_mcp_loop(
/// - Calls `streaming::process_responses_iteration_stream()` for per-iteration events /// - Calls `streaming::process_responses_iteration_stream()` for per-iteration events
/// - Emits `response.completed` at end /// - Emits `response.completed` at end
/// - Handles errors with `response.failed` /// - Handles errors with `response.failed`
///
/// # Arguments
///
/// * `ctx` - Harmony responses context with pipeline and dependencies
/// * `request` - Responses API request
///
/// # Returns
///
/// SSE stream response with proper headers
pub async fn serve_harmony_responses_stream( pub async fn serve_harmony_responses_stream(
ctx: &HarmonyResponsesContext, ctx: &HarmonyResponsesContext,
request: ResponsesRequest, request: ResponsesRequest,
...@@ -1189,6 +1180,11 @@ fn build_next_request_with_tools( ...@@ -1189,6 +1180,11 @@ fn build_next_request_with_tools(
// Update request with new items // Update request with new items
request.input = ResponseInput::Items(items); request.input = ResponseInput::Items(items);
// Switch tool_choice to "auto" for subsequent iterations
// This prevents infinite loops when original tool_choice was "required" or specific function
// After receiving tool results, the model should be free to decide whether to call more tools or finish
request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
Ok(request) Ok(request)
} }
...@@ -1214,14 +1210,6 @@ struct ToolResult { ...@@ -1214,14 +1210,6 @@ struct ToolResult {
/// ///
/// Converts MCP Tool entries (from rmcp SDK) to ResponseTool format so the model /// Converts MCP Tool entries (from rmcp SDK) to ResponseTool format so the model
/// knows about available MCP tools when making tool calls. /// knows about available MCP tools when making tool calls.
///
/// # Arguments
///
/// * `mcp_tools` - MCP tools from the MCP manager inventory (rmcp::model::Tool)
///
/// # Returns
///
/// Vector of ResponseTool entries in MCP format
pub fn convert_mcp_tools_to_response_tools(mcp_tools: &[mcp::Tool]) -> Vec<ResponseTool> { pub fn convert_mcp_tools_to_response_tools(mcp_tools: &[mcp::Tool]) -> Vec<ResponseTool> {
mcp_tools mcp_tools
.iter() .iter()
......
...@@ -12,7 +12,7 @@ use crate::{ ...@@ -12,7 +12,7 @@ use crate::{
responses::ResponsesRequest, responses::ResponsesRequest,
}, },
routers::grpc::{ routers::grpc::{
common::stages::PipelineStage, common::{responses::utils::extract_tools_from_response_tools, stages::PipelineStage},
context::{PreparationOutput, RequestContext, RequestType}, context::{PreparationOutput, RequestContext, RequestType},
error, utils, error, utils,
}, },
...@@ -84,7 +84,7 @@ impl HarmonyPreparationStage { ...@@ -84,7 +84,7 @@ impl HarmonyPreparationStage {
} }
// Step 1: Filter tools if needed // Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request); let body_ref = utils::filter_chat_request_by_tool_choice(request);
// Step 2: Build tool constraints // Step 2: Build tool constraints
let tool_constraints = if let Some(tools) = body_ref.tools.as_ref() { let tool_constraints = if let Some(tools) = body_ref.tools.as_ref() {
...@@ -128,18 +128,37 @@ impl HarmonyPreparationStage { ...@@ -128,18 +128,37 @@ impl HarmonyPreparationStage {
ctx: &mut RequestContext, ctx: &mut RequestContext,
request: &ResponsesRequest, request: &ResponsesRequest,
) -> Result<Option<Response>, Response> { ) -> Result<Option<Response>, Response> {
// Build via Harmony from responses API request // Step 1: Extract function and MCP tools with schemas from ResponseTools
let mut function_tools = extract_tools_from_response_tools(request.tools.as_deref(), true);
// Step 2: Filter tools based on tool_choice (AllowedTools or Function)
// Note: Tool existence is already validated in ResponsesRequest::validate()
if let Some(filtered) =
utils::filter_tools_by_tool_choice(&function_tools, &request.tool_choice)
{
function_tools = filtered;
}
// Step 3: Generate Harmony structural tags from filtered tools
let tool_constraints = if !function_tools.is_empty() {
Self::generate_harmony_structural_tag(&function_tools, &request.tool_choice)
.map_err(|e| *e)?
} else {
None
};
// Step 3: Build via Harmony from responses API request
let build_output = self let build_output = self
.builder .builder
.build_from_responses(request) .build_from_responses(request)
.map_err(|e| error::bad_request(format!("Harmony build failed: {}", e)))?; .map_err(|e| error::bad_request(format!("Harmony build failed: {}", e)))?;
// Store results in preparation output // Step 4: Store results with tool_constraints
ctx.state.preparation = Some(PreparationOutput { ctx.state.preparation = Some(PreparationOutput {
original_text: None, original_text: None,
token_ids: build_output.input_ids, token_ids: build_output.input_ids,
processed_messages: None, processed_messages: None,
tool_constraints: None, tool_constraints,
filtered_request: None, filtered_request: None,
harmony_mode: true, harmony_mode: true,
selection_text: Some(build_output.selection_text), selection_text: Some(build_output.selection_text),
......
...@@ -84,6 +84,7 @@ impl PipelineStage for HarmonyRequestBuildingStage { ...@@ -84,6 +84,7 @@ impl PipelineStage for HarmonyRequestBuildingStage {
placeholder_processed_text, placeholder_processed_text,
prep.token_ids.clone(), prep.token_ids.clone(),
prep.harmony_stop_ids.clone(), prep.harmony_stop_ids.clone(),
prep.tool_constraints.clone(),
) )
.map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?, .map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?,
_ => unreachable!(), _ => unreachable!(),
......
...@@ -7,14 +7,17 @@ ...@@ -7,14 +7,17 @@
//! This allows the gRPC router to reuse the existing chat pipeline infrastructure //! This allows the gRPC router to reuse the existing chat pipeline infrastructure
//! without requiring Python backend changes. //! without requiring Python backend changes.
use crate::protocols::{ use crate::{
chat::{ChatCompletionRequest, ChatCompletionResponse, ChatMessage, UserMessageContent}, protocols::{
common::{FunctionCallResponse, StreamOptions, ToolCall, UsageInfo}, chat::{ChatCompletionRequest, ChatCompletionResponse, ChatMessage, UserMessageContent},
responses::{ common::{FunctionCallResponse, StreamOptions, ToolCall, ToolChoice, UsageInfo},
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem, responses::{
ResponseReasoningContent::ReasoningText, ResponseStatus, ResponsesRequest, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
ResponsesResponse, ResponsesUsage, StringOrContentParts, ResponseReasoningContent::ReasoningText, ResponseStatus, ResponsesRequest,
ResponsesResponse, ResponsesUsage, StringOrContentParts,
},
}, },
routers::grpc::common::responses::utils::extract_tools_from_response_tools,
}; };
/// Convert a ResponsesRequest to ChatCompletionRequest for processing through the chat pipeline /// Convert a ResponsesRequest to ChatCompletionRequest for processing through the chat pipeline
...@@ -23,7 +26,8 @@ use crate::protocols::{ ...@@ -23,7 +26,8 @@ use crate::protocols::{
/// - `input` (text/items) → `messages` (chat messages) /// - `input` (text/items) → `messages` (chat messages)
/// - `instructions` → system message (prepended) /// - `instructions` → system message (prepended)
/// - `max_output_tokens` → `max_completion_tokens` /// - `max_output_tokens` → `max_completion_tokens`
/// - Tool-related fields are passed through /// - `tools` → function tools extracted from ResponseTools
/// - `tool_choice` → passed through from request
/// - Response-specific fields (previous_response_id, conversation) are handled by router /// - Response-specific fields (previous_response_id, conversation) are handled by router
pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest, String> { pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest, String> {
let mut messages = Vec::new(); let mut messages = Vec::new();
...@@ -68,69 +72,13 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest ...@@ -68,69 +72,13 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
} }
}; };
match role.as_str() { messages.push(role_to_chat_message(role.as_str(), text));
"user" => {
messages.push(ChatMessage::User {
content: UserMessageContent::Text(text),
name: None,
});
}
"assistant" => {
messages.push(ChatMessage::Assistant {
content: Some(text),
name: None,
tool_calls: None,
reasoning_content: None,
});
}
"system" => {
messages.push(ChatMessage::System {
content: text,
name: None,
});
}
_ => {
// Unknown role, treat as user message
messages.push(ChatMessage::User {
content: UserMessageContent::Text(text),
name: None,
});
}
}
} }
ResponseInputOutputItem::Message { role, content, .. } => { ResponseInputOutputItem::Message { role, content, .. } => {
// Extract text from content parts // Extract text from content parts
let text = extract_text_from_content(content); let text = extract_text_from_content(content);
match role.as_str() { messages.push(role_to_chat_message(role.as_str(), text));
"user" => {
messages.push(ChatMessage::User {
content: UserMessageContent::Text(text),
name: None,
});
}
"assistant" => {
messages.push(ChatMessage::Assistant {
content: Some(text),
name: None,
tool_calls: None,
reasoning_content: None,
});
}
"system" => {
messages.push(ChatMessage::System {
content: text,
name: None,
});
}
_ => {
// Unknown role, treat as user message
messages.push(ChatMessage::User {
content: UserMessageContent::Text(text),
name: None,
});
}
}
} }
ResponseInputOutputItem::FunctionToolCall { ResponseInputOutputItem::FunctionToolCall {
id, id,
...@@ -203,7 +151,18 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest ...@@ -203,7 +151,18 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
return Err("Request must contain at least one message".to_string()); return Err("Request must contain at least one message".to_string());
} }
// 3. Build ChatCompletionRequest // 3. Extract function tools from ResponseTools
// Only function tools are extracted here (include_mcp: false).
// MCP tools are merged later by the tool loop (see tool_loop.rs:prepare_chat_tools_and_choice)
// before the chat pipeline, where tool_choice constraints are applied to ALL tools combined.
let function_tools = extract_tools_from_response_tools(req.tools.as_deref(), false);
let tools = if function_tools.is_empty() {
None
} else {
Some(function_tools)
};
// 4. Build ChatCompletionRequest
let is_streaming = req.stream.unwrap_or(false); let is_streaming = req.stream.unwrap_or(false);
Ok(ChatCompletionRequest { Ok(ChatCompletionRequest {
...@@ -227,9 +186,8 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest ...@@ -227,9 +186,8 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
top_logprobs: req.top_logprobs, top_logprobs: req.top_logprobs,
top_p: req.top_p, top_p: req.top_p,
skip_special_tokens: true, skip_special_tokens: true,
// Note: tools and tool_choice will be handled separately for MCP transformation tools,
tools: None, // Will be set by caller if needed tool_choice: req.tool_choice.clone(),
tool_choice: None, // Will be set by caller if needed
..Default::default() ..Default::default()
}) })
} }
...@@ -247,6 +205,33 @@ fn extract_text_from_content(content: &[ResponseContentPart]) -> String { ...@@ -247,6 +205,33 @@ fn extract_text_from_content(content: &[ResponseContentPart]) -> String {
.join("") .join("")
} }
/// Build the chat message that corresponds to a Responses input role.
///
/// - `"assistant"` -> `ChatMessage::Assistant` with `text` as its content and
///   no tool calls or reasoning content
/// - `"system"` -> `ChatMessage::System`
/// - `"user"` and any unrecognized role -> `ChatMessage::User` with
///   plain-text content
///
/// `name` is left unset in every case.
fn role_to_chat_message(role: &str, text: String) -> ChatMessage {
    match role {
        "assistant" => ChatMessage::Assistant {
            content: Some(text),
            name: None,
            tool_calls: None,
            reasoning_content: None,
        },
        "system" => ChatMessage::System {
            content: text,
            name: None,
        },
        // "user" lands here as well: unknown roles are treated as user messages.
        _ => ChatMessage::User {
            content: UserMessageContent::Text(text),
            name: None,
        },
    }
}
/// Convert a ChatCompletionResponse to ResponsesResponse /// Convert a ChatCompletionResponse to ResponsesResponse
/// ///
/// # Conversion Logic /// # Conversion Logic
...@@ -354,7 +339,7 @@ pub fn chat_to_responses( ...@@ -354,7 +339,7 @@ pub fn chat_to_responses(
store: original_req.store.unwrap_or(true), store: original_req.store.unwrap_or(true),
temperature: original_req.temperature, temperature: original_req.temperature,
text: None, text: None,
tool_choice: "auto".to_string(), // TODO: Map from original request tool_choice: ToolChoice::serialize_to_string(&original_req.tool_choice),
tools: original_req.tools.clone().unwrap_or_default(), tools: original_req.tools.clone().unwrap_or_default(),
top_p: original_req.top_p, top_p: original_req.top_p,
truncation: None, truncation: None,
......
...@@ -58,7 +58,7 @@ use crate::{ ...@@ -58,7 +58,7 @@ use crate::{
}, },
protocols::{ protocols::{
chat::{self, ChatCompletionStreamResponse}, chat::{self, ChatCompletionStreamResponse},
common, common::{self, ToolChoice},
responses::{ responses::{
self, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem, self, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
ResponseReasoningContent, ResponseStatus, ResponsesRequest, ResponsesResponse, ResponseReasoningContent, ResponseStatus, ResponsesRequest, ResponsesResponse,
...@@ -657,7 +657,7 @@ impl StreamingResponseAccumulator { ...@@ -657,7 +657,7 @@ impl StreamingResponseAccumulator {
store: self.original_request.store.unwrap_or(true), store: self.original_request.store.unwrap_or(true),
temperature: self.original_request.temperature, temperature: self.original_request.temperature,
text: None, text: None,
tool_choice: "auto".to_string(), tool_choice: ToolChoice::serialize_to_string(&self.original_request.tool_choice),
tools: self.original_request.tools.clone().unwrap_or_default(), tools: self.original_request.tools.clone().unwrap_or_default(),
top_p: self.original_request.top_p, top_p: self.original_request.top_p,
truncation: None, truncation: None,
......
...@@ -13,7 +13,7 @@ use axum::{ ...@@ -13,7 +13,7 @@ use axum::{
}; };
use bytes::Bytes; use bytes::Bytes;
use futures_util::StreamExt; use futures_util::StreamExt;
use serde_json::json; use serde_json::{json, Value};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, warn}; use tracing::{debug, warn};
...@@ -24,7 +24,8 @@ use crate::{ ...@@ -24,7 +24,8 @@ use crate::{
mcp::{self, McpManager}, mcp::{self, McpManager},
protocols::{ protocols::{
chat::{ chat::{
ChatChoice, ChatCompletionMessage, ChatCompletionResponse, ChatCompletionStreamResponse, ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionStreamResponse,
}, },
common::{Function, FunctionCallResponse, Tool, ToolCall, ToolChoice, ToolChoiceValue}, common::{Function, FunctionCallResponse, Tool, ToolCall, ToolChoice, ToolChoiceValue},
responses::{ responses::{
...@@ -66,6 +67,30 @@ fn extract_function_call_from_chat( ...@@ -66,6 +67,30 @@ fn extract_function_call_from_chat(
None None
} }
/// Finalize `tools` and `tool_choice` on a chat request for one tool-loop iteration.
///
/// Tools: `mcp_chat_tools` are appended after any tools already present on
/// the request. `tools` is always `Some` afterwards, even if the merged list
/// is empty.
///
/// Tool choice:
///
/// - iteration 0 keeps the request's existing choice, defaulting to `auto`
///   when none is set;
/// - every later iteration forces `auto`, so a `required` or specific-function
///   choice cannot keep the loop calling tools forever.
fn prepare_chat_tools_and_choice(
    chat_request: &mut ChatCompletionRequest,
    mcp_chat_tools: &[Tool],
    iteration: usize,
) {
    // Append MCP tools in place, creating the list if the request had none.
    chat_request
        .tools
        .get_or_insert_with(Vec::new)
        .extend(mcp_chat_tools.iter().cloned());

    // The caller's choice survives only on the first iteration, and only if set.
    let keep_existing_choice = iteration == 0 && chat_request.tool_choice.is_some();
    if !keep_existing_choice {
        chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
    }
}
/// Extract all tool calls from chat response (for parallel tool call support) /// Extract all tool calls from chat response (for parallel tool call support)
fn extract_all_tool_calls_from_chat( fn extract_all_tool_calls_from_chat(
response: &ChatCompletionResponse, response: &ChatCompletionResponse,
...@@ -166,16 +191,13 @@ fn build_mcp_list_tools_item(mcp: &Arc<McpManager>, server_label: &str) -> Respo ...@@ -166,16 +191,13 @@ fn build_mcp_list_tools_item(mcp: &Arc<McpManager>, server_label: &str) -> Respo
let tools = mcp.list_tools(); let tools = mcp.list_tools();
let tools_info: Vec<McpToolInfo> = tools let tools_info: Vec<McpToolInfo> = tools
.iter() .iter()
.map(|t| { .map(|t| McpToolInfo {
use serde_json::Value; name: t.name.to_string(),
McpToolInfo { description: t.description.as_ref().map(|d| d.to_string()),
name: t.name.to_string(), input_schema: Value::Object((*t.input_schema).clone()),
description: t.description.as_ref().map(|d| d.to_string()), annotations: Some(json!({
input_schema: Value::Object((*t.input_schema).clone()), "read_only": false
annotations: Some(json!({ })),
"read_only": false
})),
}
}) })
.collect(); .collect();
...@@ -247,17 +269,19 @@ pub(super) async fn execute_tool_loop( ...@@ -247,17 +269,19 @@ pub(super) async fn execute_tool_loop(
// Get MCP tools and convert to chat format (do this once before loop) // Get MCP tools and convert to chat format (do this once before loop)
let mcp_tools = ctx.mcp_manager.list_tools(); let mcp_tools = ctx.mcp_manager.list_tools();
let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools); let mcp_chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
debug!("Converted {} MCP tools to chat format", chat_tools.len()); debug!(
"Converted {} MCP tools to chat format",
mcp_chat_tools.len()
);
loop { loop {
// Convert to chat request // Convert to chat request
let mut chat_request = conversions::responses_to_chat(&current_request) let mut chat_request = conversions::responses_to_chat(&current_request)
.map_err(|e| error::bad_request(format!("Failed to convert request: {}", e)))?; .map_err(|e| error::bad_request(format!("Failed to convert request: {}", e)))?;
// Add MCP tools to chat request so LLM knows about them // Prepare tools and tool_choice for this iteration
chat_request.tools = Some(chat_tools.clone()); prepare_chat_tools_and_choice(&mut chat_request, &mcp_chat_tools, state.iteration);
chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
// Execute chat pipeline (errors already have proper HTTP status codes) // Execute chat pipeline (errors already have proper HTTP status codes)
let chat_response = ctx let chat_response = ctx
...@@ -555,10 +579,10 @@ async fn execute_tool_loop_streaming_internal( ...@@ -555,10 +579,10 @@ async fn execute_tool_loop_streaming_internal(
// Get MCP tools and convert to chat format (do this once before loop) // Get MCP tools and convert to chat format (do this once before loop)
let mcp_tools = ctx.mcp_manager.list_tools(); let mcp_tools = ctx.mcp_manager.list_tools();
let chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools); let mcp_chat_tools = convert_mcp_tools_to_chat_tools(&mcp_tools);
debug!( debug!(
"Streaming: Converted {} MCP tools to chat format", "Streaming: Converted {} MCP tools to chat format",
chat_tools.len() mcp_chat_tools.len()
); );
// Flag to track if mcp_list_tools has been emitted // Flag to track if mcp_list_tools has been emitted
...@@ -584,7 +608,6 @@ async fn execute_tool_loop_streaming_internal( ...@@ -584,7 +608,6 @@ async fn execute_tool_loop_streaming_internal(
let tool_items: Vec<_> = mcp_tools let tool_items: Vec<_> = mcp_tools
.iter() .iter()
.map(|t| { .map(|t| {
use serde_json::Value;
json!({ json!({
"name": t.name, "name": t.name,
"description": t.description, "description": t.description,
...@@ -635,9 +658,8 @@ async fn execute_tool_loop_streaming_internal( ...@@ -635,9 +658,8 @@ async fn execute_tool_loop_streaming_internal(
let mut chat_request = conversions::responses_to_chat(&current_request) let mut chat_request = conversions::responses_to_chat(&current_request)
.map_err(|e| format!("Failed to convert request: {}", e))?; .map_err(|e| format!("Failed to convert request: {}", e))?;
// Add MCP tools to chat request so LLM knows about them // Prepare tools and tool_choice for this iteration (same logic as non-streaming)
chat_request.tools = Some(chat_tools.clone()); prepare_chat_tools_and_choice(&mut chat_request, &mcp_chat_tools, state.iteration);
chat_request.tool_choice = Some(ToolChoice::Value(ToolChoiceValue::Auto));
// Execute chat streaming // Execute chat streaming
let response = ctx let response = ctx
...@@ -913,7 +935,6 @@ async fn execute_tool_loop_streaming_internal( ...@@ -913,7 +935,6 @@ async fn execute_tool_loop_streaming_internal(
/// Convert MCP tools to Chat API tool format /// Convert MCP tools to Chat API tool format
fn convert_mcp_tools_to_chat_tools(mcp_tools: &[mcp::Tool]) -> Vec<Tool> { fn convert_mcp_tools_to_chat_tools(mcp_tools: &[mcp::Tool]) -> Vec<Tool> {
use serde_json::Value;
mcp_tools mcp_tools
.iter() .iter()
.map(|tool_info| Tool { .map(|tool_info| Tool {
......
...@@ -40,7 +40,7 @@ impl ChatPreparationStage { ...@@ -40,7 +40,7 @@ impl ChatPreparationStage {
request: &ChatCompletionRequest, request: &ChatCompletionRequest,
) -> Result<(), Response> { ) -> Result<(), Response> {
// Step 1: Filter tools if needed // Step 1: Filter tools if needed
let body_ref = utils::filter_tools_for_request(request); let body_ref = utils::filter_chat_request_by_tool_choice(request);
// Step 2: Process messages and apply chat template // Step 2: Process messages and apply chat template
let processed_messages = let processed_messages =
......
...@@ -9,7 +9,6 @@ use tracing::{error, warn}; ...@@ -9,7 +9,6 @@ use tracing::{error, warn};
use uuid::Uuid; use uuid::Uuid;
use super::{error, ProcessedMessages}; use super::{error, ProcessedMessages};
pub use crate::tokenizer::StopSequenceDecoder;
use crate::{ use crate::{
core::Worker, core::Worker,
grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient}, grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient},
...@@ -28,8 +27,9 @@ use crate::{ ...@@ -28,8 +27,9 @@ use crate::{
tokenizer::{ tokenizer::{
cache::CachedTokenizer, cache::CachedTokenizer,
chat_template::{ChatTemplateContentFormat, ChatTemplateParams}, chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
stop::StopSequenceDecoderBuilder,
traits::Tokenizer, traits::Tokenizer,
HuggingFaceTokenizer, HuggingFaceTokenizer, StopSequenceDecoder,
}, },
tool_parser::{ tool_parser::{
ParserFactory as ToolParserFactory, PooledParser as ToolPooledParser, ToolParser, ParserFactory as ToolParserFactory, PooledParser as ToolPooledParser, ToolParser,
...@@ -273,39 +273,57 @@ fn build_required_array_schema(tools: &[Tool]) -> Result<String, String> { ...@@ -273,39 +273,57 @@ fn build_required_array_schema(tools: &[Tool]) -> Result<String, String> {
.map_err(|e| format!("Failed to serialize tool schema: {}", e)) .map_err(|e| format!("Failed to serialize tool schema: {}", e))
} }
/// Filter tools based on tool_choice (shared by both routers) /// Filter tools based on tool_choice (generic helper)
/// Returns a reference to the original body if no filtering needed, ///
/// otherwise returns a cloned and filtered body /// Returns filtered tools if filtering is needed, otherwise returns None.
pub fn filter_tools_for_request( /// Used by both Chat API and Responses API (Harmony) for constraint generation.
body: &ChatCompletionRequest, pub fn filter_tools_by_tool_choice(
) -> std::borrow::Cow<'_, ChatCompletionRequest> { tools: &[Tool],
match &body.tool_choice { tool_choice: &Option<ToolChoice>,
Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => { ) -> Option<Vec<Tool>> {
let mut filtered_body = body.clone(); match tool_choice {
let all_tools = filtered_body.tools.as_ref().unwrap(); Some(ToolChoice::AllowedTools { tools: allowed, .. }) => {
let allowed_names: std::collections::HashSet<&str> = let allowed_names: std::collections::HashSet<&str> =
allowed.iter().map(|t| t.name.as_str()).collect(); allowed.iter().filter_map(|t| t.function_name()).collect();
let filtered_tools: Vec<Tool> = all_tools let filtered: Vec<Tool> = tools
.iter() .iter()
.filter(|t| allowed_names.contains(t.function.name.as_str())) .filter(|t| allowed_names.contains(t.function.name.as_str()))
.cloned() .cloned()
.collect(); .collect();
filtered_body.tools = Some(filtered_tools); Some(filtered)
std::borrow::Cow::Owned(filtered_body)
} }
Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => { Some(ToolChoice::Function { function, .. }) => {
let mut filtered_body = body.clone(); let filtered: Vec<Tool> = tools
let all_tools = filtered_body.tools.as_ref().unwrap();
let filtered_tools: Vec<Tool> = all_tools
.iter() .iter()
.filter(|t| t.function.name == function.name) .filter(|t| t.function.name == function.name)
.cloned() .cloned()
.collect(); .collect();
Some(filtered)
}
_ => None, // No filtering needed
}
}
/// Filter ChatCompletionRequest by tool_choice
///
/// Returns a reference to the original request if no filtering needed,
/// otherwise returns a cloned request with filtered tools.
///
/// Note: Tool existence is validated earlier in ChatCompletionRequest::validate(),
/// so this function assumes tool_choice references valid tools.
pub fn filter_chat_request_by_tool_choice(
body: &ChatCompletionRequest,
) -> std::borrow::Cow<'_, ChatCompletionRequest> {
if let Some(tools) = &body.tools {
if let Some(filtered_tools) = filter_tools_by_tool_choice(tools, &body.tool_choice) {
let mut filtered_body = body.clone();
filtered_body.tools = Some(filtered_tools); filtered_body.tools = Some(filtered_tools);
std::borrow::Cow::Owned(filtered_body) return std::borrow::Cow::Owned(filtered_body);
} }
_ => std::borrow::Cow::Borrowed(body), // No filtering needed, use original
} }
// No filtering needed - return original request
std::borrow::Cow::Borrowed(body)
} }
/// Process chat messages and apply template (shared by both routers) /// Process chat messages and apply template (shared by both routers)
...@@ -438,8 +456,6 @@ pub fn create_stop_decoder( ...@@ -438,8 +456,6 @@ pub fn create_stop_decoder(
skip_special_tokens: bool, skip_special_tokens: bool,
no_stop_trim: bool, no_stop_trim: bool,
) -> StopSequenceDecoder { ) -> StopSequenceDecoder {
use crate::tokenizer::stop::StopSequenceDecoderBuilder;
// Extract stop sequences // Extract stop sequences
let stop_sequences: Vec<String> = match stop { let stop_sequences: Vec<String> = match stop {
Some(StringOrArray::String(s)) => vec![s.clone()], Some(StringOrArray::String(s)) => vec![s.clone()],
......
...@@ -349,8 +349,7 @@ fn test_tool_choice_allowed_tools_invalid_mode() { ...@@ -349,8 +349,7 @@ fn test_tool_choice_allowed_tools_invalid_mode() {
}]), }]),
tool_choice: Some(ToolChoice::AllowedTools { tool_choice: Some(ToolChoice::AllowedTools {
mode: "invalid_mode".to_string(), mode: "invalid_mode".to_string(),
tools: vec![ToolReference { tools: vec![ToolReference::Function {
tool_type: "function".to_string(),
name: "get_weather".to_string(), name: "get_weather".to_string(),
}], }],
tool_type: "function".to_string(), tool_type: "function".to_string(),
...@@ -387,8 +386,7 @@ fn test_tool_choice_allowed_tools_valid_mode_auto() { ...@@ -387,8 +386,7 @@ fn test_tool_choice_allowed_tools_valid_mode_auto() {
}]), }]),
tool_choice: Some(ToolChoice::AllowedTools { tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(), mode: "auto".to_string(),
tools: vec![ToolReference { tools: vec![ToolReference::Function {
tool_type: "function".to_string(),
name: "get_weather".to_string(), name: "get_weather".to_string(),
}], }],
tool_type: "function".to_string(), tool_type: "function".to_string(),
...@@ -419,8 +417,7 @@ fn test_tool_choice_allowed_tools_valid_mode_required() { ...@@ -419,8 +417,7 @@ fn test_tool_choice_allowed_tools_valid_mode_required() {
}]), }]),
tool_choice: Some(ToolChoice::AllowedTools { tool_choice: Some(ToolChoice::AllowedTools {
mode: "required".to_string(), mode: "required".to_string(),
tools: vec![ToolReference { tools: vec![ToolReference::Function {
tool_type: "function".to_string(),
name: "get_weather".to_string(), name: "get_weather".to_string(),
}], }],
tool_type: "function".to_string(), tool_type: "function".to_string(),
...@@ -451,8 +448,7 @@ fn test_tool_choice_allowed_tools_tool_not_found() { ...@@ -451,8 +448,7 @@ fn test_tool_choice_allowed_tools_tool_not_found() {
}]), }]),
tool_choice: Some(ToolChoice::AllowedTools { tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(), mode: "auto".to_string(),
tools: vec![ToolReference { tools: vec![ToolReference::Function {
tool_type: "function".to_string(),
name: "nonexistent_tool".to_string(), name: "nonexistent_tool".to_string(),
}], }],
tool_type: "function".to_string(), tool_type: "function".to_string(),
...@@ -501,12 +497,10 @@ fn test_tool_choice_allowed_tools_multiple_tools_valid() { ...@@ -501,12 +497,10 @@ fn test_tool_choice_allowed_tools_multiple_tools_valid() {
tool_choice: Some(ToolChoice::AllowedTools { tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(), mode: "auto".to_string(),
tools: vec![ tools: vec![
ToolReference { ToolReference::Function {
tool_type: "function".to_string(),
name: "get_weather".to_string(), name: "get_weather".to_string(),
}, },
ToolReference { ToolReference::Function {
tool_type: "function".to_string(),
name: "get_time".to_string(), name: "get_time".to_string(),
}, },
], ],
...@@ -550,12 +544,10 @@ fn test_tool_choice_allowed_tools_one_invalid_among_valid() { ...@@ -550,12 +544,10 @@ fn test_tool_choice_allowed_tools_one_invalid_among_valid() {
tool_choice: Some(ToolChoice::AllowedTools { tool_choice: Some(ToolChoice::AllowedTools {
mode: "auto".to_string(), mode: "auto".to_string(),
tools: vec![ tools: vec![
ToolReference { ToolReference::Function {
tool_type: "function".to_string(),
name: "get_weather".to_string(), name: "get_weather".to_string(),
}, },
ToolReference { ToolReference::Function {
tool_type: "function".to_string(),
name: "nonexistent_tool".to_string(), name: "nonexistent_tool".to_string(),
}, },
], ],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment