"vscode:/vscode.git/clone" did not exist on "2fa94956f4e500bf5c42263124c758d8613ee05e"
Unverified commit b658be6f authored by Chang Su, committed by GitHub
Browse files

[router][grpc] Support tool call parser in streaming (#11160)

parent 5e786cca
use crate::tool_parser::parsers::{
DeepSeekParser, Glm4MoeParser, GptOssHarmonyParser, GptOssParser, JsonParser, KimiK2Parser,
LlamaParser, MistralParser, PythonicParser, QwenParser, Step3Parser,
};
use crate::tool_parser::traits::ToolParser;
use once_cell::sync::Lazy;
use std::{collections::HashMap, env, sync::Arc};
/// Global singleton registry instance — created lazily on first access and reused
/// for the lifetime of the process.
pub static GLOBAL_REGISTRY: Lazy<ParserRegistry> = Lazy::new(ParserRegistry::new_internal);
/// Registry for tool parsers and model mappings.
///
/// Lookup (see `get_parser`) tries an exact model-name match first, then the
/// longest matching `prefix*` pattern, then falls back to `default_parser`.
pub struct ParserRegistry {
    /// Map of parser name to parser instance (shared via `Arc`)
    parsers: HashMap<String, Arc<dyn ToolParser>>,
    /// Map of model name (exact) or `prefix*` pattern to parser name
    model_mapping: HashMap<String, String>,
    /// Default parser name to use when no match is found ("json" out of the box)
    default_parser: String,
}
impl ParserRegistry {
    /// Return the process-wide singleton registry.
    pub fn new() -> &'static Self {
        &GLOBAL_REGISTRY
    }

    /// Build a fresh, non-singleton registry (test-only).
    #[cfg(test)]
    pub fn new_for_testing() -> Self {
        Self::new_internal()
    }

    /// Construct a registry pre-populated with the default parsers and
    /// default model mappings.
    fn new_internal() -> Self {
        let mut registry = Self {
            parsers: HashMap::new(),
            model_mapping: HashMap::new(),
            default_parser: "json".to_string(),
        };
        registry.register_default_parsers();
        registry.register_default_mappings();
        registry
    }

    /// Register (or replace) a parser under `name`.
    pub fn register_parser(&mut self, name: impl Into<String>, parser: Arc<dyn ToolParser>) {
        self.parsers.insert(name.into(), parser);
    }

    /// Associate a model name (exact) or `prefix*` pattern with a parser name.
    pub fn map_model(&mut self, model: impl Into<String>, parser: impl Into<String>) {
        self.model_mapping.insert(model.into(), parser.into());
    }

    /// Resolve the parser for `model`.
    ///
    /// Resolution order: exact model-name match, then the longest (most
    /// specific) matching `prefix*` pattern, then the default parser.
    pub fn get_parser(&self, model: &str) -> Option<Arc<dyn ToolParser>> {
        // 1. Exact model name.
        if let Some(parser) = self
            .model_mapping
            .get(model)
            .and_then(|name| self.parsers.get(name))
        {
            return Some(parser.clone());
        }

        // 2. Wildcard patterns ("prefix*"): gather every match and order by
        //    pattern length descending so longer prefixes win.
        let mut candidates: Vec<(&str, &String)> = self
            .model_mapping
            .iter()
            .filter_map(|(pattern, parser_name)| {
                let prefix = pattern.strip_suffix('*')?;
                model
                    .starts_with(prefix)
                    .then_some((pattern.as_str(), parser_name))
            })
            .collect();
        candidates.sort_by_key(|(pattern, _)| std::cmp::Reverse(pattern.len()));
        for (_, parser_name) in candidates {
            if let Some(parser) = self.parsers.get(parser_name) {
                return Some(parser.clone());
            }
        }

        // 3. Fall back to the default parser, if one is registered.
        self.parsers.get(&self.default_parser).cloned()
    }

    /// Names of all registered parsers (unordered).
    pub fn list_parsers(&self) -> Vec<&str> {
        self.parsers.keys().map(String::as_str).collect()
    }

    /// All (pattern, parser-name) mappings (unordered).
    pub fn list_mappings(&self) -> Vec<(&str, &str)> {
        self.model_mapping
            .iter()
            .map(|(pattern, parser)| (pattern.as_str(), parser.as_str()))
            .collect()
    }

    /// Install the built-in parser implementations.
    fn register_default_parsers(&mut self) {
        self.register_parser("json", Arc::new(JsonParser::new())); // plain JSON (most common)
        self.register_parser("mistral", Arc::new(MistralParser::new())); // [TOOL_CALLS] [...]
        self.register_parser("qwen", Arc::new(QwenParser::new())); // <tool_call>...</tool_call>
        self.register_parser("pythonic", Arc::new(PythonicParser::new())); // [func(arg=val)]
        self.register_parser("llama", Arc::new(LlamaParser::new())); // <|python_tag|>{...} or plain JSON
        self.register_parser("deepseek", Arc::new(DeepSeekParser::new())); // DeepSeek V3 unicode tokens + JSON
        self.register_parser("glm4_moe", Arc::new(Glm4MoeParser::new())); // GLM-4 MoE XML key-values
        self.register_parser("step3", Arc::new(Step3Parser::new())); // StepTML XML
        self.register_parser("kimik2", Arc::new(KimiK2Parser::new())); // Kimi K2 token-based, indexed functions

        // "gpt_oss" aliases either the Harmony or the legacy variant, chosen
        // once at startup via the ROUTER_USE_HARMONY_GPT_OSS env var; both
        // concrete variants remain reachable under their explicit names.
        let legacy = Arc::new(GptOssParser::new());
        let harmony = Arc::new(GptOssHarmonyParser::new());
        self.register_parser("gpt_oss_legacy", legacy.clone());
        self.register_parser("gpt_oss_harmony", harmony.clone());
        let chosen: Arc<dyn ToolParser> = if use_harmony_gpt_oss() { harmony } else { legacy };
        self.register_parser("gpt_oss", chosen);
    }

    /// Install the built-in model-to-parser mappings.
    fn register_default_mappings(&mut self) {
        const DEFAULT_MAPPINGS: &[(&str, &str)] = &[
            // OpenAI models
            ("gpt-4*", "json"),
            ("gpt-3.5*", "json"),
            ("gpt-4o*", "json"),
            // Anthropic models
            ("claude-*", "json"),
            // Mistral models - use Mistral parser
            ("mistral-*", "mistral"),
            ("mixtral-*", "mistral"),
            // Qwen models (either capitalization) - use Qwen parser
            ("qwen*", "qwen"),
            ("Qwen*", "qwen"),
            // Llama 4 uses pythonic format
            ("llama-4*", "pythonic"),
            ("meta-llama-4*", "pythonic"),
            // Llama 3.2 uses python_tag format
            ("llama-3.2*", "llama"),
            ("meta-llama-3.2*", "llama"),
            // Other Llama models use JSON
            ("llama-*", "json"),
            ("meta-llama-*", "json"),
            // DeepSeek V3 uses custom Unicode token format
            ("deepseek-v3*", "deepseek"),
            ("deepseek-ai/DeepSeek-V3*", "deepseek"),
            // DeepSeek V2 uses pythonic format
            ("deepseek-*", "pythonic"),
            // GLM-4.5 and GLM-4.6 use XML-style format
            ("glm-4.5*", "glm4_moe"),
            ("glm-4.6*", "glm4_moe"),
            // Other GLM models may use JSON
            ("glm-*", "json"),
            // Step3 models
            ("step3*", "step3"),
            ("Step-3*", "step3"),
            // Kimi models
            ("kimi-k2*", "kimik2"),
            ("Kimi-K2*", "kimik2"),
            ("moonshot*/Kimi-K2*", "kimik2"),
            // GPT-OSS models (T4-style)
            ("gpt-oss*", "gpt_oss"),
            ("t4-*", "gpt_oss"),
            // Other models default to JSON
            ("gemini-*", "json"),
            ("palm-*", "json"),
            ("gemma-*", "json"),
        ];
        for (pattern, parser) in DEFAULT_MAPPINGS {
            self.map_model(*pattern, *parser);
        }
    }

    /// Set the default parser name used when no mapping matches.
    pub fn set_default_parser(&mut self, name: impl Into<String>) {
        self.default_parser = name.into();
    }

    /// Check whether a parser is registered under `name`.
    pub fn has_parser(&self, name: &str) -> bool {
        self.parsers.contains_key(name)
    }
}
/// Whether the Harmony GPT-OSS parser should back the "gpt_oss" alias.
///
/// Controlled by the `ROUTER_USE_HARMONY_GPT_OSS` environment variable.
/// Truthy values (case-insensitive, surrounding whitespace ignored):
/// "1", "true", "yes", "on". Anything else — including an unset or
/// non-UTF-8 variable — yields `false`.
///
/// Note: the original hand-enumerated mixed-case spellings ("TRUE", "Yes",
/// "ON", ...) but still missed variants like "tRue"; normalizing to ASCII
/// lowercase accepts a strict superset of the previous values.
fn use_harmony_gpt_oss() -> bool {
    env::var("ROUTER_USE_HARMONY_GPT_OSS")
        .map(|value| {
            matches!(
                value.trim().to_ascii_lowercase().as_str(),
                "1" | "true" | "yes" | "on"
            )
        })
        .unwrap_or(false)
}
impl Default for &'static ParserRegistry {
fn default() -> Self {
ParserRegistry::new()
}
}
use crate::tool_parser::types::{PartialToolCall, ToolCall};
/// Current phase of streaming tool-call parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParsePhase {
    /// Scanning input for the start of a tool call
    Searching,
    /// Parsing the function name
    InName,
    /// Parsing the function arguments
    InArguments,
    /// Tool call complete (closing brace/bracket seen)
    Complete,
}
/// State for the streaming (incremental) tool-call parser.
#[derive(Debug, Clone)]
pub struct ParseState {
    /// Buffer accumulating raw input text across chunks
    pub buffer: String,
    /// Byte offset into `buffer` up to which content has been consumed
    pub consumed: usize,
    /// Tool call currently being assembled, if any
    pub partial_tool: Option<PartialToolCall>,
    /// Tool calls fully parsed so far
    pub completed_tools: Vec<ToolCall>,
    /// Current parsing phase
    pub phase: ParsePhase,
    /// Net `{`/`[` minus `}`/`]` depth (outside string literals only)
    pub bracket_depth: i32,
    /// Whether the scanner is currently inside a JSON string literal
    pub in_string: bool,
    /// Whether the next character is backslash-escaped
    pub escape_next: bool,
    /// Index of the tool currently being streamed (increments per completed tool)
    pub tool_index: usize,
    /// Optional Harmony-specific streaming state (populated by token-aware parsers)
    pub harmony_stream: Option<HarmonyStreamState>,
}
impl ParseState {
    /// Create an empty state, ready to begin scanning for a tool call.
    pub fn new() -> Self {
        Self {
            buffer: String::new(),
            consumed: 0,
            partial_tool: None,
            completed_tools: Vec::new(),
            phase: ParsePhase::Searching,
            bracket_depth: 0,
            in_string: false,
            escape_next: false,
            tool_index: 0,
            harmony_stream: None,
        }
    }

    /// Reset per-tool state so the next tool call can be parsed.
    ///
    /// The buffer, consumed offset, completed tools, and tool index are
    /// intentionally preserved across resets.
    pub fn reset(&mut self) {
        self.partial_tool = None;
        self.phase = ParsePhase::Searching;
        self.bracket_depth = 0;
        self.in_string = false;
        self.escape_next = false;
        self.harmony_stream = None;
    }

    /// Feed one character into the JSON structure tracker and append it to
    /// the buffer, updating string/escape/bracket bookkeeping.
    pub fn process_char(&mut self, ch: char) {
        // A character following a backslash inside a string is taken
        // literally: it never toggles string state or bracket depth.
        if self.escape_next {
            self.escape_next = false;
            self.buffer.push(ch);
            return;
        }
        if ch == '\\' && self.in_string {
            self.escape_next = true;
            self.buffer.push(ch);
            return;
        }
        // An unescaped quote toggles string state. (escape_next is
        // necessarily false here — the branch above returned — so the
        // original redundant `!self.escape_next` guard is dropped.)
        if ch == '"' {
            self.in_string = !self.in_string;
        }
        // Brackets/braces only count as structure outside string literals.
        if !self.in_string {
            match ch {
                '{' | '[' => self.bracket_depth += 1,
                '}' | ']' => {
                    self.bracket_depth -= 1;
                    // Depth returning to zero while a tool is in progress
                    // means its top-level JSON value just closed.
                    if self.bracket_depth == 0 && self.partial_tool.is_some() {
                        self.phase = ParsePhase::Complete;
                    }
                }
                _ => {}
            }
        }
        self.buffer.push(ch);
    }

    /// Check if the buffer holds a complete (balanced) JSON object/array.
    pub fn has_complete_json(&self) -> bool {
        self.bracket_depth == 0 && !self.in_string && !self.buffer.is_empty()
    }

    /// Smallest index >= `pos` that is a UTF-8 character boundary in the
    /// buffer, clamped to the buffer length. Byte offsets tracked externally
    /// may land inside a multi-byte character; slicing there would panic.
    fn char_boundary_at_or_after(&self, pos: usize) -> usize {
        let mut p = pos.min(self.buffer.len());
        while p < self.buffer.len() && !self.buffer.is_char_boundary(p) {
            p += 1;
        }
        p
    }

    /// Buffer contents from byte offset `start` onward ("" when past the end).
    pub fn extract_from(&self, start: usize) -> &str {
        let safe_start = self.char_boundary_at_or_after(start);
        &self.buffer[safe_start..]
    }

    /// Mark content as consumed up to `position` (monotonic; never rewinds).
    pub fn consume_to(&mut self, position: usize) {
        self.consumed = self.consumed.max(position);
    }

    /// Buffer contents not yet marked as consumed.
    pub fn unconsumed(&self) -> &str {
        let safe_consumed = self.char_boundary_at_or_after(self.consumed);
        &self.buffer[safe_consumed..]
    }

    /// Drop consumed bytes from the front of the buffer and rebase
    /// `consumed` accordingly. Snaps *backward* to a character boundary so a
    /// partially-consumed multi-byte character is kept in the buffer.
    pub fn clear_consumed(&mut self) {
        if self.consumed == 0 {
            return;
        }
        let mut cut = self.consumed.min(self.buffer.len());
        while cut > 0 && !self.buffer.is_char_boundary(cut) {
            cut -= 1;
        }
        if cut > 0 {
            self.buffer.drain(..cut);
            self.consumed = self.consumed.saturating_sub(cut);
        }
    }

    /// Record a finished tool call and advance the streaming tool index.
    pub fn add_completed_tool(&mut self, tool: ToolCall) {
        self.completed_tools.push(tool);
        self.tool_index += 1;
    }
}
/// Equivalent to [`ParseState::new`].
impl Default for ParseState {
    fn default() -> Self {
        Self::new()
    }
}
/// Placeholder for Harmony streaming metadata captured during token-aware parsing.
#[derive(Debug, Clone, Default)]
pub struct HarmonyStreamState {
......
...@@ -5,64 +5,27 @@ use crate::tool_parser::partial_json::{ ...@@ -5,64 +5,27 @@ use crate::tool_parser::partial_json::{
}; };
use crate::tool_parser::traits::ToolParser; use crate::tool_parser::traits::ToolParser;
#[test] #[tokio::test]
fn test_parse_state_new() { async fn test_tool_parser_factory() {
let state = ParseState::new(); let factory = ToolParserFactory::new();
assert_eq!(state.phase, ParsePhase::Searching);
assert_eq!(state.buffer, "");
assert_eq!(state.consumed, 0);
assert_eq!(state.bracket_depth, 0);
assert!(!state.in_string);
assert!(!state.escape_next);
}
#[test]
fn test_parse_state_process_char() {
let mut state = ParseState::new();
state.process_char('{');
assert_eq!(state.bracket_depth, 1);
state.process_char('}');
assert_eq!(state.bracket_depth, 0);
state.process_char('"');
assert!(state.in_string);
state.process_char('"');
assert!(!state.in_string);
state.process_char('"');
state.process_char('\\');
assert!(state.escape_next);
state.process_char('"');
assert!(!state.escape_next);
assert!(state.in_string); // Still in string because quote was escaped
}
#[test]
fn test_parser_registry() {
let registry = ParserRegistry::new();
assert!(!registry.list_mappings().is_empty());
let mappings = registry.list_mappings(); // Test that we can get a pooled parser
let has_gpt = mappings.iter().any(|(m, _)| m.starts_with("gpt")); let pooled_parser = factory.get_pooled("gpt-4");
assert!(has_gpt); let parser = pooled_parser.lock().await;
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
} }
#[test] #[tokio::test]
fn test_parser_registry_pattern_matching() { async fn test_tool_parser_factory_model_mapping() {
let mut registry = ParserRegistry::new_for_testing(); let factory = ToolParserFactory::new();
registry.map_model("test-model", "json"); // Test model mapping
factory.registry().map_model("test-model", "json");
let mappings = registry.list_mappings(); // Get parser for the test model
let has_test = mappings let pooled_parser = factory.get_pooled("test-model");
.iter() let parser = pooled_parser.lock().await;
.any(|(m, p)| *m == "test-model" && *p == "json"); assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(has_test);
} }
#[test] #[test]
...@@ -165,37 +128,7 @@ fn test_compute_diff() { ...@@ -165,37 +128,7 @@ fn test_compute_diff() {
assert_eq!(compute_diff("test", "hello"), "hello"); assert_eq!(compute_diff("test", "hello"), "hello");
} }
#[test] // NOTE: test_stream_result_variants removed - StreamResult enum replaced by StreamingParseResult
fn test_stream_result_variants() {
let result = StreamResult::Incomplete;
matches!(result, StreamResult::Incomplete);
let result = StreamResult::ToolName {
index: 0,
name: "test".to_string(),
};
if let StreamResult::ToolName { index, name } = result {
assert_eq!(index, 0);
assert_eq!(name, "test");
} else {
panic!("Expected ToolName variant");
}
let tool = ToolCall {
id: "123".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
name: "test".to_string(),
arguments: "{}".to_string(),
},
};
let result = StreamResult::ToolComplete(tool.clone());
if let StreamResult::ToolComplete(t) = result {
assert_eq!(t.id, "123");
} else {
panic!("Expected ToolComplete variant");
}
}
#[test] #[test]
fn test_partial_tool_call() { fn test_partial_tool_call() {
...@@ -310,14 +243,12 @@ fn test_json_parser_format_detection() { ...@@ -310,14 +243,12 @@ fn test_json_parser_format_detection() {
} }
#[tokio::test] #[tokio::test]
async fn test_registry_with_json_parser() { async fn test_factory_with_json_parser() {
let registry = ParserRegistry::new(); let factory = ToolParserFactory::new();
// JSON parser should be registered by default
assert!(registry.has_parser("json"));
// Should get JSON parser for OpenAI models // Should get JSON parser for OpenAI models
let parser = registry.get_parser("gpt-4-turbo").unwrap(); let pooled_parser = factory.get_pooled("gpt-4-turbo");
let parser = pooled_parser.lock().await;
let input = r#"{"name": "test", "arguments": {"x": 1}}"#; let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
...@@ -546,62 +477,6 @@ mod edge_cases { ...@@ -546,62 +477,6 @@ mod edge_cases {
assert!(tools[0].function.arguments.contains("null")); assert!(tools[0].function.arguments.contains("null"));
} }
#[tokio::test]
async fn test_streaming_with_partial_chunks() {
let parser = JsonParser::new();
let mut state1 = ParseState::new();
let partial = r#"{"#;
let result = parser
.parse_incremental(partial, &mut state1)
.await
.unwrap();
assert!(
matches!(result, StreamResult::Incomplete),
"Should return Incomplete for just opening brace"
);
let mut state2 = ParseState::new();
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
let result = parser
.parse_incremental(complete, &mut state2)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value =
serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "SF");
}
_ => panic!("Expected ToolComplete for complete JSON"),
}
// The PartialJson parser can complete partial JSON by filling in missing values
let mut state3 = ParseState::new();
let partial_with_name = r#"{"name": "test", "argum"#;
let result = parser
.parse_incremental(partial_with_name, &mut state3)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "test");
// Arguments will be empty object since "argum" is incomplete
assert_eq!(tool.function.arguments, "{}");
}
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "test");
}
StreamResult::Incomplete => {
// Also acceptable if parser decides to wait
}
_ => panic!("Unexpected result for partial JSON with name"),
}
}
#[tokio::test] #[tokio::test]
async fn test_special_json_values() { async fn test_special_json_values() {
let parser = JsonParser::new(); let parser = JsonParser::new();
......
use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::ToolParserResult, errors::ToolParserResult,
state::ParseState, types::{StreamingParseResult, ToolCall},
types::{StreamResult, ToolCall},
}; };
use async_trait::async_trait; use async_trait::async_trait;
...@@ -13,11 +13,16 @@ pub trait ToolParser: Send + Sync { ...@@ -13,11 +13,16 @@ pub trait ToolParser: Send + Sync {
async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)>; async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)>;
/// Parse tool calls from model output (streaming) /// Parse tool calls from model output (streaming)
/// Parsers now maintain internal state, so self is mutable
///
/// # Arguments
/// * `chunk` - New text chunk from model output
/// * `tools` - List of available tools for validation
async fn parse_incremental( async fn parse_incremental(
&self, &mut self,
chunk: &str, chunk: &str,
state: &mut ParseState, tools: &[Tool],
) -> ToolParserResult<StreamResult>; ) -> ToolParserResult<StreamingParseResult>;
/// Check if text contains tool calls in this parser's format /// Check if text contains tool calls in this parser's format
fn detect_format(&self, text: &str) -> bool; fn detect_format(&self, text: &str) -> bool;
...@@ -50,9 +55,10 @@ pub trait TokenToolParser: ToolParser { ...@@ -50,9 +55,10 @@ pub trait TokenToolParser: ToolParser {
) -> ToolParserResult<(String, Vec<ToolCall>)>; ) -> ToolParserResult<(String, Vec<ToolCall>)>;
/// Streaming parser entrypoint for token chunks. /// Streaming parser entrypoint for token chunks.
/// Parsers maintain internal state, so self is mutable
async fn parse_incremental_tokens( async fn parse_incremental_tokens(
&self, &mut self,
tokens: &[u32], tokens: &[u32],
state: &mut ParseState, tools: &[Tool],
) -> ToolParserResult<StreamResult>; ) -> ToolParserResult<StreamingParseResult>;
} }
...@@ -71,3 +71,23 @@ pub struct PartialToolCall { ...@@ -71,3 +71,23 @@ pub struct PartialToolCall {
/// Arguments already streamed /// Arguments already streamed
pub streamed_args: String, pub streamed_args: String,
} }
/// Result of streaming parse operation (matches Python StreamingParseResult)
#[derive(Debug, Clone, Default)]
pub struct StreamingParseResult {
/// Normal text that's not part of tool calls
pub normal_text: String,
/// Tool call items parsed from the chunk
pub calls: Vec<ToolCallItem>,
}
/// Simple encapsulation of parsed tool call for streaming (matches Python ToolCallItem)
#[derive(Debug, Clone)]
pub struct ToolCallItem {
/// Tool index in the array
pub tool_index: usize,
/// Tool name (only present on first chunk)
pub name: Option<String>,
/// Incremental JSON arguments
pub parameters: String,
}
...@@ -6,7 +6,9 @@ pub mod mock_openai_server; ...@@ -6,7 +6,9 @@ pub mod mock_openai_server;
pub mod mock_worker; pub mod mock_worker;
pub mod test_app; pub mod test_app;
use serde_json::json;
use sglang_router_rs::config::RouterConfig; use sglang_router_rs::config::RouterConfig;
use sglang_router_rs::protocols::spec::{Function, Tool};
use sglang_router_rs::server::AppContext; use sglang_router_rs::server::AppContext;
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
...@@ -100,3 +102,284 @@ pub const EXPECTED_HASHES: [u64; 4] = [ ...@@ -100,3 +102,284 @@ pub const EXPECTED_HASHES: [u64; 4] = [
6245658446118930933, 6245658446118930933,
5097285695902185237, 5097285695902185237,
]; ];
/// Create a comprehensive set of test tools covering all parser test scenarios
#[allow(dead_code)]
pub fn create_test_tools() -> Vec<Tool> {
vec![
Tool {
tool_type: "function".to_string(),
function: Function {
name: "search".to_string(),
description: Some("Search for information".to_string()),
parameters: json!({
"type": "object",
"properties": {
"query": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather information".to_string()),
parameters: json!({
"type": "object",
"properties": {
"city": {"type": "string"},
"location": {"type": "string"},
"date": {"type": "string"},
"units": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "calculate".to_string(),
description: Some("Perform calculations".to_string()),
parameters: json!({
"type": "object",
"properties": {
"x": {"type": "number"},
"y": {"type": "number"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "translate".to_string(),
description: Some("Translate text".to_string()),
parameters: json!({
"type": "object",
"properties": {
"text": {"type": "string"},
"to": {"type": "string"},
"target_lang": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_time".to_string(),
description: Some("Get current time".to_string()),
parameters: json!({
"type": "object",
"properties": {
"timezone": {"type": "string"},
"format": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_current_time".to_string(),
description: Some("Get current time".to_string()),
parameters: json!({
"type": "object",
"properties": {
"timezone": {"type": "string"},
"format": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "update_settings".to_string(),
description: Some("Update settings".to_string()),
parameters: json!({
"type": "object",
"properties": {
"preferences": {"type": "object"},
"notifications": {"type": "boolean"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "ping".to_string(),
description: Some("Ping service".to_string()),
parameters: json!({"type": "object", "properties": {}}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "test".to_string(),
description: Some("Test function".to_string()),
parameters: json!({"type": "object", "properties": {}}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "process".to_string(),
description: Some("Process data".to_string()),
parameters: json!({
"type": "object",
"properties": {
"count": {"type": "number"},
"rate": {"type": "number"},
"enabled": {"type": "boolean"},
"data": {"type": "object"},
"text": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "web_search".to_string(),
description: Some("Search the web".to_string()),
parameters: json!({
"type": "object",
"properties": {
"query": {"type": "string"},
"num_results": {"type": "number"},
"search_type": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_tourist_attractions".to_string(),
description: Some("Get tourist attractions".to_string()),
parameters: json!({
"type": "object",
"properties": {
"city": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "config".to_string(),
description: Some("Configuration function".to_string()),
parameters: json!({
"type": "object",
"properties": {
"debug": {"type": "boolean"},
"verbose": {"type": "boolean"},
"optional": {"type": "null"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "test_func".to_string(),
description: Some("Test function".to_string()),
parameters: json!({
"type": "object",
"properties": {
"bool_true": {"type": "boolean"},
"bool_false": {"type": "boolean"},
"none_val": {"type": "null"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "create".to_string(),
description: Some("Create resource".to_string()),
parameters: json!({
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "add".to_string(),
description: Some("Add operation".to_string()),
parameters: json!({
"type": "object",
"properties": {
"x": {"type": "number"},
"y": {"type": "number"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "calc".to_string(),
description: Some("Calculate".to_string()),
parameters: json!({
"type": "object",
"properties": {
"x": {"type": "number"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "func1".to_string(),
description: Some("Function 1".to_string()),
parameters: json!({"type": "object", "properties": {}}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "func2".to_string(),
description: Some("Function 2".to_string()),
parameters: json!({
"type": "object",
"properties": {
"y": {"type": "number"}
}
}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "tool1".to_string(),
description: Some("Tool 1".to_string()),
parameters: json!({"type": "object", "properties": {}}),
},
},
Tool {
tool_type: "function".to_string(),
function: Function {
name: "tool2".to_string(),
description: Some("Tool 2".to_string()),
parameters: json!({
"type": "object",
"properties": {
"y": {"type": "number"}
}
}),
},
},
]
}
//! DeepSeek V3 Parser Integration Tests //! DeepSeek V3 Parser Integration Tests
use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, ToolParser}; use sglang_router_rs::tool_parser::{DeepSeekParser, ToolParser};
mod common;
use common::create_test_tools;
#[tokio::test] #[tokio::test]
async fn test_deepseek_complete_parsing() { async fn test_deepseek_complete_parsing() {
...@@ -46,8 +49,9 @@ async fn test_deepseek_multiple_tools() { ...@@ -46,8 +49,9 @@ async fn test_deepseek_multiple_tools() {
#[tokio::test] #[tokio::test]
async fn test_deepseek_streaming() { async fn test_deepseek_streaming() {
let parser = DeepSeekParser::new(); let tools = create_test_tools();
let mut state = ParseState::new();
let mut parser = DeepSeekParser::new();
// Simulate streaming chunks // Simulate streaming chunks
let chunks = vec![ let chunks = vec![
...@@ -61,25 +65,19 @@ async fn test_deepseek_streaming() { ...@@ -61,25 +65,19 @@ async fn test_deepseek_streaming() {
]; ];
let mut found_name = false; let mut found_name = false;
let mut found_complete = false;
for chunk in chunks { for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); let result = parser.parse_incremental(chunk, &tools).await.unwrap();
match result { for call in result.calls {
StreamResult::ToolName { name, .. } => { if let Some(name) = call.name {
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
found_name = true; found_name = true;
} }
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
found_complete = true;
}
_ => {}
} }
} }
assert!(found_name || found_complete); assert!(found_name, "Should have found tool name during streaming");
} }
#[tokio::test] #[tokio::test]
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment