Unverified Commit 963175d5 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Support streaming for v1/chat/completions (#11179)

parent 0618ad6d
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use uuid;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
...@@ -84,16 +83,7 @@ impl LlamaParser { ...@@ -84,16 +83,7 @@ impl LlamaParser {
let arguments = serde_json::to_string(parameters) let arguments = serde_json::to_string(parameters)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate a unique ID for Llama calls
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("llama_call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall { Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: name.to_string(), name: name.to_string(),
arguments, arguments,
...@@ -243,4 +233,8 @@ impl ToolParser for LlamaParser { ...@@ -243,4 +233,8 @@ impl ToolParser for LlamaParser {
text.contains("<|python_tag|>") text.contains("<|python_tag|>")
|| (text.trim_start().starts_with('{') && text.contains(r#""name""#)) || (text.trim_start().starts_with('{') && text.contains(r#""name""#))
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
} }
...@@ -146,16 +146,7 @@ impl MistralParser { ...@@ -146,16 +146,7 @@ impl MistralParser {
let arguments = serde_json::to_string(args) let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate unique ID
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("mistral_call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall { Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: name.to_string(), name: name.to_string(),
arguments, arguments,
...@@ -266,4 +257,8 @@ impl ToolParser for MistralParser { ...@@ -266,4 +257,8 @@ impl ToolParser for MistralParser {
fn detect_format(&self, text: &str) -> bool { fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
} }
...@@ -244,7 +244,7 @@ fn parse_python_expression(source: &str) -> ToolParserResult<Expr> { ...@@ -244,7 +244,7 @@ fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
} }
} }
fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> { fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
match expr { match expr {
Expr::Call(call_expr) => { Expr::Call(call_expr) => {
if !call_expr.args.is_empty() { if !call_expr.args.is_empty() {
...@@ -277,8 +277,6 @@ fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> { ...@@ -277,8 +277,6 @@ fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> {
let arguments_string = serde_json::to_string(&arguments_json)?; let arguments_string = serde_json::to_string(&arguments_json)?;
Ok(ToolCall { Ok(ToolCall {
id: format!("call-{}", index + 1),
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: function_name, name: function_name,
arguments: arguments_string, arguments: arguments_string,
......
...@@ -88,16 +88,7 @@ impl QwenParser { ...@@ -88,16 +88,7 @@ impl QwenParser {
let arguments = serde_json::to_string(args) let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate unique ID
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall { Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: name.to_string(), name: name.to_string(),
arguments, arguments,
...@@ -255,4 +246,8 @@ impl ToolParser for QwenParser { ...@@ -255,4 +246,8 @@ impl ToolParser for QwenParser {
fn detect_format(&self, text: &str) -> bool { fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
} }
...@@ -400,12 +400,7 @@ impl Step3Parser { ...@@ -400,12 +400,7 @@ impl Step3Parser {
let arguments_str = serde_json::to_string(&parameters) let arguments_str = serde_json::to_string(&parameters)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("step3_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall { Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: func_name.to_string(), name: func_name.to_string(),
arguments: arguments_str, arguments: arguments_str,
...@@ -561,4 +556,8 @@ impl ToolParser for Step3Parser { ...@@ -561,4 +556,8 @@ impl ToolParser for Step3Parser {
fn detect_format(&self, text: &str) -> bool { fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
} }
...@@ -31,8 +31,6 @@ async fn test_tool_parser_factory_model_mapping() { ...@@ -31,8 +31,6 @@ async fn test_tool_parser_factory_model_mapping() {
#[test] #[test]
fn test_tool_call_serialization() { fn test_tool_call_serialization() {
let tool_call = ToolCall { let tool_call = ToolCall {
id: "call-123".to_string(),
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: "search".to_string(), name: "search".to_string(),
arguments: r#"{"query": "rust programming"}"#.to_string(), arguments: r#"{"query": "rust programming"}"#.to_string(),
...@@ -40,13 +38,15 @@ fn test_tool_call_serialization() { ...@@ -40,13 +38,15 @@ fn test_tool_call_serialization() {
}; };
let json = serde_json::to_string(&tool_call).unwrap(); let json = serde_json::to_string(&tool_call).unwrap();
assert!(json.contains("call-123"));
assert!(json.contains("search")); assert!(json.contains("search"));
assert!(json.contains("rust programming")); assert!(json.contains("rust programming"));
let parsed: ToolCall = serde_json::from_str(&json).unwrap(); let parsed: ToolCall = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, "call-123");
assert_eq!(parsed.function.name, "search"); assert_eq!(parsed.function.name, "search");
assert_eq!(
parsed.function.arguments,
r#"{"query": "rust programming"}"#
);
} }
#[test] #[test]
......
...@@ -32,6 +32,12 @@ pub trait ToolParser: Send + Sync { ...@@ -32,6 +32,12 @@ pub trait ToolParser: Send + Sync {
fn as_token_parser(&self) -> Option<&dyn TokenToolParser> { fn as_token_parser(&self) -> Option<&dyn TokenToolParser> {
None None
} }
/// Get unstreamed tool call arguments
/// Returns tool call items for arguments that have been parsed but not yet streamed
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
None
}
} }
/// Trait for partial JSON parsing /// Trait for partial JSON parsing
......
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Parsed tool call from model output (OpenAI format) /// Parsed tool call from model output
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall { pub struct ToolCall {
/// Unique identifier for the tool call
pub id: String,
/// Type of tool call (currently always "function")
#[serde(rename = "type")]
pub r#type: String,
/// Function call details /// Function call details
pub function: FunctionCall, pub function: FunctionCall,
} }
......
...@@ -181,7 +181,6 @@ fn test_chatml_template() { ...@@ -181,7 +181,6 @@ fn test_chatml_template() {
content: Some("Hi there!".to_string()), content: Some("Hi there!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
function_call: None,
reasoning_content: None, reasoning_content: None,
}, },
spec::ChatMessage::User { spec::ChatMessage::User {
......
...@@ -68,7 +68,6 @@ mod tests { ...@@ -68,7 +68,6 @@ mod tests {
content: Some("Hi there".to_string()), content: Some("Hi there".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
function_call: None,
reasoning_content: None, reasoning_content: None,
}, },
]; ];
...@@ -213,7 +212,6 @@ mod tests { ...@@ -213,7 +212,6 @@ mod tests {
content: Some("World".to_string()), content: Some("World".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
function_call: None,
reasoning_content: None, reasoning_content: None,
}, },
]; ];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment