Unverified commit 4fe53e58 authored by Chang Su, committed by GitHub

[router][grpc] Support streaming parsing with Tool Choice in chat completions API (#12677)

parent fb2e816e
......@@ -274,9 +274,6 @@ class TestToolChoiceLlama32(CustomTestCase):
self.assertIsNotNone(tool_calls)
self.assertGreater(len(tool_calls), 0)
@unittest.skip(
"Skipping required streaming test as it is not supported by the router"
)
def test_tool_choice_required_streaming(self):
"""Test tool_choice='required' in streaming mode"""
tools = self.get_test_tools()
......@@ -325,9 +322,6 @@ class TestToolChoiceLlama32(CustomTestCase):
for tool_call in tool_calls:
self.assertEqual(tool_call.function.name, "get_weather")
@unittest.skip(
"Skipping required streaming test as it is not supported by the router"
)
def test_tool_choice_specific_function_streaming(self):
"""Test tool_choice with specific function in streaming mode"""
tools = self.get_test_tools()
......@@ -363,9 +357,6 @@ class TestToolChoiceLlama32(CustomTestCase):
self.assertEqual(found_name, "get_weather")
@unittest.skip(
"Skipping required streaming arguments chunks json test as it is not supported by the router"
)
def test_required_streaming_arguments_chunks_json(self):
"""In streaming required mode, complete tool call arguments should be valid JSON when all chunks are combined"""
tools = self.get_test_tools()
......
......@@ -62,9 +62,15 @@ def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
except Exception:
pass
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
list(ex.map(_direct_load, range(128)))
time.sleep(1)
# Start background load in a non-blocking way to keep the slow worker busy
background_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
background_futures = []
for i in range(32):
future = background_executor.submit(_direct_load, i)
background_futures.append(future)
# Wait longer for the load monitor to update (at least 2 monitor intervals)
time.sleep(3)
def call(i):
r = requests.post(
......@@ -85,6 +91,9 @@ def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
for wid in ex.map(call, range(200)):
counts[wid] += 1
# Clean up background executor
background_executor.shutdown(wait=False)
# Expect the slow worker (higher latency/inflight) to receive fewer requests
fast_worker_id = [i for i in ids if i != slow_id][0]
assert counts[slow_id] < counts[fast_worker_id], counts
//! Context for /v1/responses endpoint handlers
//!
//! Bundles all dependencies needed by responses handlers to avoid passing
//! 10+ parameters to every function (fixes clippy::too_many_arguments).
//! 10+ parameters to every function.
use std::{collections::HashMap, sync::Arc};
......
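For illustration, a minimal sketch of the context-struct pattern described above (field names are hypothetical, not the router's actual dependencies):
use std::{collections::HashMap, sync::Arc};
// Hypothetical context: handlers take one `&ResponsesContext` instead of
// threading 10+ individual parameters through every call.
struct ResponsesContext {
    model_registry: Arc<HashMap<String, String>>,
    default_model: Arc<str>,
}
fn resolve_model(ctx: &ResponsesContext, requested: Option<&str>) -> String {
    // Each handler pulls only the dependencies it needs from the context.
    requested
        .and_then(|m| ctx.model_registry.get(m).cloned())
        .unwrap_or_else(|| ctx.default_model.to_string())
}
fn main() {
    let ctx = ResponsesContext {
        model_registry: Arc::new(HashMap::from([("gpt".to_string(), "gpt-4o".to_string())])),
        default_model: Arc::from("default-model"),
    };
    assert_eq!(resolve_model(&ctx, Some("gpt")), "gpt-4o");
}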
......@@ -210,6 +210,17 @@ impl StreamingProcessor {
model,
);
// Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match tool_choice {
Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false,
};
// Check if this is the specific function case (LLM generates parameters only, no name field)
let is_specific_function = matches!(tool_choice, Some(ToolChoice::Function { .. }));
let tool_parser_available = tools.is_some()
&& utils::check_tool_parser_availability(
&self.tool_parser_factory,
......@@ -342,10 +353,24 @@ impl StreamingProcessor {
if !in_reasoning
&& tool_choice_enabled
&& tools.is_some()
&& tool_parser_available
&& (tool_parser_available || used_json_schema)
{
let tool_chunks = self
.process_tool_calls_stream(
let tool_chunks = if is_specific_function {
// Handle specific function case - emit tool call deltas with arguments
Self::process_specific_function_stream(
&delta,
index,
&mut has_tool_calls,
tool_choice,
request_id,
model,
created,
system_fingerprint,
history_tool_calls_count,
)
} else {
// Use incremental parser for regular/required modes
self.process_tool_calls_stream(
&delta,
index,
&mut tool_parsers,
......@@ -356,8 +381,10 @@ impl StreamingProcessor {
created,
system_fingerprint,
history_tool_calls_count,
used_json_schema,
)
.await;
.await
};
for chunk in tool_chunks {
Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
......@@ -1089,6 +1116,101 @@ impl StreamingProcessor {
(delta.to_string(), None, false)
}
/// Helper: Process specific function case - emit tool call deltas with arguments
#[allow(clippy::too_many_arguments)]
fn process_specific_function_stream(
delta: &str,
index: u32,
has_tool_calls: &mut HashMap<u32, bool>,
tool_choice: &Option<ToolChoice>,
request_id: &str,
model: &str,
created: u64,
system_fingerprint: Option<&str>,
history_tool_calls_count: usize,
) -> Vec<ChatCompletionStreamResponse> {
let mut chunks = Vec::new();
if let Some(ToolChoice::Function { function, .. }) = tool_choice {
let is_first_call = !has_tool_calls.contains_key(&index);
if is_first_call {
// First chunk: send name and id
has_tool_calls.insert(index, true);
let tool_call_id = utils::generate_tool_call_id(
model,
&function.name,
0,
history_tool_calls_count,
);
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: Some(tool_call_id),
tool_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(function.name.clone()),
arguments: None,
}),
}]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// Emit arguments delta
if !delta.is_empty() {
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: None,
tool_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(delta.to_string()),
}),
}]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
}
chunks
}
/// Helper: Process tool calls in streaming mode
#[allow(clippy::too_many_arguments)]
async fn process_tool_calls_stream(
......@@ -1103,17 +1225,27 @@ impl StreamingProcessor {
created: u64,
system_fingerprint: Option<&str>,
history_tool_calls_count: usize,
use_json_parser: bool,
) -> Vec<ChatCompletionStreamResponse> {
let mut chunks = Vec::new();
// Create fresh parser for this index (not pooled, to avoid state pollution)
tool_parsers.entry(index).or_insert_with(|| {
let parser = utils::create_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
.expect("Parser should be available - checked upfront");
let parser = if use_json_parser {
utils::create_tool_parser(
&self.tool_parser_factory,
Some(&"json".to_string()),
model,
)
.expect("JSON parser should be available")
} else {
utils::create_tool_parser(
&self.tool_parser_factory,
self.configured_tool_parser.as_ref(),
model,
)
.expect("Parser should be available - checked upfront")
};
Arc::new(tokio::sync::Mutex::new(parser))
});
......
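Sketched as data, the delta sequence the specific-function path above produces looks roughly like this (values illustrative, including the tool call id; built with serde_json for readability):
use serde_json::json;
fn main() {
    // First delta per tool call: id, type, and function name, no arguments yet.
    let first = json!({
        "delta": {
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_0_get_weather", // illustrative id format
                "type": "function",
                "function": { "name": "get_weather" }
            }]
        }
    });
    // Every later delta: an arguments fragment only; id and name are omitted.
    let next = json!({
        "delta": {
            "tool_calls": [{
                "index": 0,
                "function": { "arguments": "{\"city\": \"Par" }
            }]
        }
    });
    println!("{first}\n{next}");
}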
......@@ -21,6 +21,8 @@ use crate::{
/// - Unicode token delimiters
/// - JSON arguments in code blocks
/// - Support for multiple sequential tool calls
///
/// Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
pub struct DeepSeekParser {
/// Regex for extracting complete tool calls
tool_call_extractor: Regex,
......
......@@ -217,8 +217,18 @@ pub fn handle_json_tool_streaming(
}
};
// Check if JSON is complete
let is_complete = end_idx == json_str.len() && serde_json::from_str::<Value>(json_str).is_ok();
// Check if JSON is complete - validate only the parsed portion
// Ensure end_idx is on a valid UTF-8 character boundary
let safe_end_idx = if json_str.is_char_boundary(end_idx) {
end_idx
} else {
// Find the nearest valid character boundary before end_idx
(0..end_idx)
.rev()
.find(|&i| json_str.is_char_boundary(i))
.unwrap_or(0)
};
let is_complete = serde_json::from_str::<Value>(&json_str[..safe_end_idx]).is_ok();
// Validate tool name if present
if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
......
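A self-contained sketch of that boundary backtracking on a multi-byte string (std only; mirrors the scan above):
fn safe_boundary(s: &str, end_idx: usize) -> usize {
    if s.is_char_boundary(end_idx) {
        end_idx
    } else {
        // Walk backwards to the nearest valid boundary; index 0 is always valid.
        (0..end_idx).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0)
    }
}
fn main() {
    let json_str = r#"{"city": "世界"}"#;
    let start = json_str.find('世').unwrap();
    let inside = start + 1; // lands inside the 3-byte encoding of 世
    assert!(!json_str.is_char_boundary(inside));
    // Slicing &json_str[..inside] would panic; the backtracked index is safe.
    assert_eq!(safe_boundary(json_str, inside), start);
    let _prefix = &json_str[..safe_boundary(json_str, inside)];
}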
......@@ -39,6 +39,12 @@ pub struct JsonParser {
/// Separator between multiple tool calls
tool_call_separator: &'static str,
/// Track whether we're parsing array format `[...]` vs single object `{...}`
is_array_format: bool,
/// Track whether we've already stripped the closing ] bracket (for array format)
array_closed: bool,
}
impl JsonParser {
......@@ -52,6 +58,8 @@ impl JsonParser {
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
tool_call_separator: ",",
is_array_format: false,
array_closed: false,
}
}
......@@ -211,14 +219,31 @@ impl ToolParser for JsonParser {
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
// Determine format on first parse (array vs single object)
if self.current_tool_id == -1 && self.has_tool_markers(current_text) {
self.is_array_format = current_text.trim().starts_with('[');
}
// Check if current_text has tool_call
let has_tool_start = self.has_tool_markers(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
// Once array is closed, don't treat [ or { as tool markers
let has_tool_start = (!self.array_closed && self.has_tool_markers(current_text))
|| (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
let normal_text = self.buffer.clone();
let mut normal_text = self.buffer.clone();
self.buffer.clear();
// Strip ] only once (the closing bracket of JSON array format)
// Only for array format and only if we haven't already closed it
if self.is_array_format
&& !self.array_closed
&& self.current_tool_id > 0
&& normal_text.starts_with("]")
{
normal_text = normal_text.strip_prefix("]").unwrap().to_string();
self.array_closed = true;
}
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
......@@ -233,12 +258,12 @@ impl ToolParser for JsonParser {
let start_idx = if let Some(bracket_pos) = current_text.find('[') {
let brace_pos = current_text.find('{');
match brace_pos {
Some(bp) if bp < bracket_pos => bp,
Some(bp) => bp,
_ => bracket_pos,
}
} else if let Some(brace_pos) = current_text.find('{') {
brace_pos
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
} else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
......@@ -274,5 +299,7 @@ impl ToolParser for JsonParser {
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
self.is_array_format = false;
self.array_closed = false;
}
}
......@@ -21,6 +21,8 @@ use crate::{
/// - Token-based delimiters
/// - Function calls with explicit indexing
/// - JSON arguments
///
/// Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
pub struct KimiK2Parser {
/// Regex for extracting complete tool calls
tool_call_extractor: Regex,
......
......@@ -181,7 +181,7 @@ impl ToolParser for LlamaParser {
// Check if current_text has tool_call
let has_tool_start = self.has_tool_markers(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|| (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
......@@ -205,7 +205,7 @@ impl ToolParser for LlamaParser {
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
} else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
......
......@@ -17,10 +17,7 @@ use crate::{
/// Handles the Mistral-specific format:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
/// Features:
/// - Bracket counting for proper JSON array extraction
/// - Support for multiple tool calls in a single array
/// - String-aware parsing to handle nested brackets in JSON
/// Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
pub struct MistralParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
......@@ -42,7 +39,11 @@ pub struct MistralParser {
/// Token configuration
bot_token: &'static str,
eot_token: &'static str,
tool_call_separator: &'static str,
/// Track whether we've already stripped the closing ] bracket
array_closed: bool,
}
impl MistralParser {
......@@ -56,7 +57,9 @@ impl MistralParser {
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
bot_token: "[TOOL_CALLS] [",
eot_token: "]",
tool_call_separator: ", ",
array_closed: false,
}
}
......@@ -207,14 +210,27 @@ impl ToolParser for MistralParser {
// Check if current_text has tool_call
let has_tool_start = self.has_tool_markers(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|| (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
let normal_text = self.buffer.clone();
let mut normal_text = self.buffer.clone();
self.buffer.clear();
// Strip ] only once (the closing bracket of [TOOL_CALLS] array)
// current_tool_id > 0 means we've parsed at least one tool
if !self.array_closed
&& self.current_tool_id > 0
&& normal_text.starts_with(self.eot_token)
{
normal_text = normal_text
.strip_prefix(self.eot_token)
.unwrap()
.to_string();
self.array_closed = true;
}
return Ok(StreamingParseResult {
normal_text,
calls: vec![],
......@@ -231,7 +247,7 @@ impl ToolParser for MistralParser {
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
} else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
......@@ -266,5 +282,6 @@ impl ToolParser for MistralParser {
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
self.array_closed = false;
}
}
......@@ -9,6 +9,7 @@ use std::sync::OnceLock;
///
/// This format is used by Llama models and uses Python literals
/// rather than JSON for arguments.
/// Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
use async_trait::async_trait;
use num_traits::ToPrimitive;
use regex::Regex;
......
......@@ -19,10 +19,11 @@ use crate::{
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
///
/// Features:
/// - XML-style tags with JSON content
/// - Support for multiple sequential tool calls
/// - Newline-aware parsing
/// - Buffering for partial end tokens
/// - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
/// - Each individual call is separated by `\n`
/// - Function Call Object: JSON object with "name" and "arguments" fields
///
/// Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
pub struct QwenParser {
/// Parser for handling incomplete JSON during streaming
partial_json: PartialJson,
......@@ -49,8 +50,9 @@ pub struct QwenParser {
normal_text_buffer: String,
/// Token configuration
bot_token: &'static str,
eot_token: &'static str,
/// Start/end tokens for each individual tool call (not the entire sequence)
individual_tool_start_token: &'static str,
individual_tool_end_token: &'static str,
tool_call_separator: &'static str,
}
......@@ -70,8 +72,8 @@ impl QwenParser {
current_tool_name_sent: false,
streamed_args_for_tool: Vec::new(),
normal_text_buffer: String::new(),
bot_token: "<tool_call>\n",
eot_token: "\n</tool_call>",
individual_tool_start_token: "<tool_call>\n",
individual_tool_end_token: "\n</tool_call>",
tool_call_separator: "\n",
}
}
......@@ -157,11 +159,13 @@ impl ToolParser for QwenParser {
// Check if current_text has tool_call
let has_tool_start = self.has_tool_markers(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|| (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start {
// Only clear buffer if we're sure no tool call is starting
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
if helpers::ends_with_partial_token(&self.buffer, self.individual_tool_start_token)
.is_none()
{
let normal_text = self.buffer.clone();
self.buffer.clear();
......@@ -170,7 +174,7 @@ impl ToolParser for QwenParser {
calls: vec![],
});
} else {
// Might be partial bot_token, keep buffering
// Might be partial individual_tool_start_token, keep buffering
return Ok(StreamingParseResult::default());
}
}
......@@ -179,9 +183,9 @@ impl ToolParser for QwenParser {
let tool_indices = helpers::get_tool_indices(tools);
// Determine start index for JSON parsing
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
pos + self.bot_token.len()
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
let start_idx = if let Some(pos) = current_text.find(self.individual_tool_start_token) {
pos + self.individual_tool_start_token.len()
} else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
self.tool_call_separator.len()
} else {
0
......@@ -205,7 +209,7 @@ impl ToolParser for QwenParser {
self.normal_text_buffer.push_str(&result.normal_text);
// Check if buffer contains complete end token (without leading newline)
let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
let end_token_without_newline = &self.individual_tool_end_token[1..]; // "</tool_call>"
if self.normal_text_buffer.contains(end_token_without_newline) {
// Complete end token found - clean it and return
let cleaned_text = self
......
......@@ -5,6 +5,9 @@
use serde_json::json;
use sglang_router_rs::tool_parser::{JsonParser, ToolParser};
mod common;
use common::{create_test_tools, streaming_helpers::*};
#[tokio::test]
async fn test_simple_json_tool_call() {
let parser = JsonParser::new();
......@@ -159,3 +162,556 @@ async fn test_json_format_detection() {
assert!(parser.has_tool_markers(r#"[{"name": "test"}]"#));
assert!(!parser.has_tool_markers("plain text"));
}
// Streaming tests for JSON array format
#[tokio::test]
async fn test_json_array_streaming_required_mode() {
use sglang_router_rs::protocols::common::Tool;
// Test that simulates the exact streaming pattern from required mode
let mut parser = JsonParser::new();
// Define test tools
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
// Simulate the EXACT chunks from the debug log
let chunks = vec![
"[{",
" \"",
"name",
"\":",
" \"",
"get",
"_weather",
"\",",
" \"",
"parameters",
"\":",
" {",
" \"",
"city",
"\":",
" \"",
"Paris",
"\"",
" }",
" }]",
];
let mut all_results = Vec::new();
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_results.extend(result.calls);
all_normal_text.push_str(&result.normal_text);
}
// We should have gotten tool call chunks
assert!(
!all_results.is_empty(),
"Should have emitted tool call chunks"
);
// Should not have emitted any normal text (including the closing ])
assert_eq!(
all_normal_text, "",
"Should not emit normal text for JSON array format"
);
// Check that we got the function name
let has_name = all_results
.iter()
.any(|item| item.name.as_ref().is_some_and(|n| n == "get_weather"));
assert!(has_name, "Should have emitted function name");
// Check that we got the parameters
let has_params = all_results.iter().any(|item| !item.parameters.is_empty());
assert!(has_params, "Should have emitted parameters");
}
#[tokio::test]
async fn test_json_array_multiple_tools_streaming() {
use sglang_router_rs::protocols::common::Tool;
// Test with multiple tools in array
let mut parser = JsonParser::new();
let tools = vec![
Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
},
Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_news".to_string(),
description: Some("Get news".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
},
];
// Split into smaller, more realistic chunks
let chunks = vec![
"[{",
"\"name\":",
"\"get_weather\"",
",\"parameters\":",
"{\"city\":",
"\"SF\"}",
"}",
",",
"{\"name\":",
"\"get_news\"",
",\"parameters\":",
"{\"topic\":",
"\"tech\"}",
"}]",
];
let mut all_results = Vec::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_results.extend(result.calls);
}
// Should have gotten tool calls for both functions
let has_weather = all_results
.iter()
.any(|item| item.name.as_ref().is_some_and(|n| n == "get_weather"));
let has_news = all_results
.iter()
.any(|item| item.name.as_ref().is_some_and(|n| n == "get_news"));
assert!(has_weather, "Should have get_weather tool call");
assert!(has_news, "Should have get_news tool call");
}
#[tokio::test]
async fn test_json_array_closing_bracket_separate_chunk() {
use sglang_router_rs::protocols::common::Tool;
// Test case where the closing ] comes as a separate chunk
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}];
// Closing ] as separate chunk, followed by normal text
let chunks = vec![
"[{",
"\"",
"name",
"\":",
"\"",
"get",
"_weather",
"\",",
"\"",
"parameters",
"\":",
"{",
"\"",
"city",
"\":",
"\"",
"Paris",
"\"",
"}",
"}",
"]",
" Here's",
" the",
" weather",
" info",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should emit only the trailing text as normal text, NOT the closing ]
assert_eq!(
all_normal_text, " Here's the weather info",
"Should emit only normal text without ], got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_json_single_object_with_trailing_text() {
use sglang_router_rs::protocols::common::Tool;
// Test single object format (no array) with trailing text
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
let chunks = vec![
"{",
"\"",
"name",
"\":",
"\"",
"get_weather",
"\",",
"\"",
"parameters",
"\":",
"{",
"\"city",
"\":",
"\"Paris",
"\"}",
"}",
" Here's",
" the",
" weather",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should emit the trailing text as normal_text (no ] to strip for single object)
assert_eq!(
all_normal_text, " Here's the weather",
"Should emit normal text for single object format, got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_json_single_object_with_bracket_in_text() {
use sglang_router_rs::protocols::common::Tool;
// Test that ] in normal text is NOT stripped for single object format
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
let chunks = vec![
"{",
"\"name",
"\":",
"\"get_weather",
"\",",
"\"parameters",
"\":",
"{",
"\"city",
"\":",
"\"Paris",
"\"}",
"}",
"]",
" Here's",
" the",
" weather",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// For single object format, ] should NOT be stripped (it's part of normal text)
assert_eq!(
all_normal_text, "] Here's the weather",
"Should preserve ] in normal text for single object format, got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_json_array_bracket_in_text_after_tools() {
use sglang_router_rs::protocols::common::Tool;
// Test that ] in normal text AFTER array tools is preserved
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
let chunks = vec![
"[",
"{",
"\"name",
"\":",
"\"get_weather",
"\",",
"\"parameters",
"\":",
"{",
"\"city",
"\":",
"\"Paris",
"\"}",
"}",
"]",
" Array",
" notation:",
" arr",
"[",
"0",
"]",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should preserve ] in normal text after array tools complete
assert_eq!(
all_normal_text, " Array notation: arr[0]",
"Should preserve ] in normal text after array tools, got: '{}'",
all_normal_text
);
}
// =============================================================================
// REALISTIC STREAMING TESTS
// =============================================================================
#[tokio::test]
async fn test_json_bug_incomplete_tool_name_string() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// This exact sequence triggered the bug:
// Parser receives {"name": " and must NOT parse it as empty name
let chunks = vec![
r#"{"#,
r#"""#,
r#"name"#,
r#"""#,
r#":"#,
r#" "#,
r#"""#, // ← Critical moment: parser has {"name": "
// At this point, partial_json should NOT allow incomplete strings
// when current_tool_name_sent=false
r#"search"#, // Use valid tool name from create_test_tools()
r#"""#,
r#", "#,
r#"""#,
r#"arguments"#,
r#"""#,
r#": {"#,
r#"""#,
r#"query"#,
r#"""#,
r#": "#,
r#"""#,
r#"rust programming"#,
r#"""#,
r#"}}"#,
];
let mut got_tool_name = false;
let mut saw_empty_name = false;
for chunk in chunks.iter() {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = &call.name {
if name.is_empty() {
saw_empty_name = true;
}
if name == "search" {
got_tool_name = true;
}
}
}
}
assert!(
!saw_empty_name,
"Parser should NEVER return empty tool name"
);
assert!(got_tool_name, "Should have parsed tool name correctly");
}
#[tokio::test]
async fn test_json_realistic_chunks_simple_tool() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 10, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_strategic_chunks_with_quotes() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#;
let chunks = create_strategic_chunks(input);
// Strategic chunks break after quotes and colons
assert!(chunks.iter().any(|c| c.ends_with('"')));
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_incremental_arguments_streaming() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "test", "limit": 10}}"#;
let chunks = create_realistic_chunks(input);
let mut tool_name_sent = false;
let mut got_arguments = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
tool_name_sent = true;
}
if tool_name_sent && !call.parameters.is_empty() {
got_arguments = true;
}
}
}
assert!(tool_name_sent, "Should have sent tool name");
assert!(got_arguments, "Should have sent arguments");
}
#[tokio::test]
async fn test_json_very_long_url_in_arguments() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// Simulate long URL arriving in many chunks
let long_url = "https://example.com/very/long/path/".to_string() + &"segment/".repeat(50);
let input = format!(
r#"{{"name": "search", "arguments": {{"query": "{}"}}}}"#,
long_url
);
let chunks = create_realistic_chunks(&input);
assert!(chunks.len() > 100, "Long URL should create many chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_unicode() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "Hello 世界 🌍"}}"#;
let chunks = create_realistic_chunks(input);
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed with unicode");
}
......@@ -5,7 +5,7 @@
use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};
mod common;
use common::create_test_tools;
use common::{create_test_tools, streaming_helpers::*};
#[tokio::test]
async fn test_llama_python_tag_format() {
......@@ -397,3 +397,59 @@ async fn test_llama_streaming_multiple_tools_chunked() {
}
}
}
// =============================================================================
// REALISTIC STREAMING TESTS
// =============================================================================
#[tokio::test]
async fn test_llama_realistic_chunks_with_python_tag() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 15, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "calculate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_llama_python_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
// Python tag itself arrives in small chunks
let chunks = vec![
"<|p", "yth", "on_", "tag", "|>{", r#"""#, "na", r#"me""#, ": ", r#"""#, "sea", "rch",
r#"""#, ", ", r#"""#, "par", "ame", "ter", "s", r#"""#, ": {", r#"""#, "q", r#"""#, ": ",
r#"""#, "tes", "t", r#"""#, "}}",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "search");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
......@@ -155,3 +155,120 @@ Let me execute these searches for you."#;
assert_eq!(tools[0].function.name, "web_search");
assert_eq!(tools[1].function.name, "get_weather");
}
#[tokio::test]
async fn test_mistral_streaming_closing_bracket() {
use sglang_router_rs::protocols::common::Tool;
// Test that closing ] is stripped for Mistral array format
let mut parser = MistralParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}];
let chunks = vec![
"[TOOL_CALLS] ",
"[{",
"\"",
"name",
"\":",
"\"",
"get",
"_weather",
"\",",
"\"",
"arguments",
"\":",
"{",
"\"",
"city",
"\":",
"\"",
"Paris",
"\"",
"}",
"}",
"]",
" Here's",
" the weather",
" info",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should emit only the trailing text as normal text, NOT the closing ]
assert_eq!(
all_normal_text, " Here's the weather info",
"Should not emit ] for Mistral array format, got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_mistral_streaming_bracket_in_text_after_tools() {
use sglang_router_rs::protocols::common::Tool;
// Test that ] in normal text AFTER tool calls is preserved
let mut parser = MistralParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}];
let chunks = vec![
"[TOOL_CALLS] ",
"[",
"{",
"\"name",
"\":",
"\"get_weather",
"\",",
"\"arguments",
"\":",
"{\"",
"city",
"\":",
"\"Paris",
"\"}",
"}",
"]",
" Array",
" notation:",
" arr",
"[",
"0",
"]",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should preserve ] in normal text after tools complete
assert_eq!(
all_normal_text, " Array notation: arr[0]",
"Should preserve ] in normal text after tools, got: '{}'",
all_normal_text
);
}
......@@ -6,7 +6,7 @@ use serde_json::json;
use sglang_router_rs::tool_parser::{QwenParser, ToolParser};
mod common;
use common::create_test_tools;
use common::{create_test_tools, streaming_helpers::*};
#[tokio::test]
async fn test_qwen_single_tool() {
......@@ -250,3 +250,58 @@ async fn test_buffer_efficiency_with_multiple_tools() {
}
}
}
// =============================================================================
// REALISTIC STREAMING TESTS
// =============================================================================
#[tokio::test]
async fn test_qwen_realistic_chunks_with_xml_tags() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let input = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Tokyo\"}}\n</tool_call>";
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 20, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_qwen_xml_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let chunks = vec![
"<to", "ol_", "cal", "l>\n", "{", r#"""#, "na", "me", r#"""#, ": ", r#"""#, "tra", "nsl",
"ate", r#"""#, ", ", r#"""#, "arg", "ume", "nts", r#"""#, ": {", r#"""#, "tex", "t",
r#"""#, ": ", r#"""#, "hel", "lo", r#"""#, "}}\n", "</t", "ool", "_ca", "ll>",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "translate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
//! Realistic Streaming Parser Tests
//!
//! Tests incremental parsing with realistic char-level chunks (2-5 chars)
//! that simulate how LLM tokens actually arrive.
//!
//! These tests are designed to catch bugs like `{"name": "` being parsed
//! as an empty tool name.
use sglang_router_rs::tool_parser::{JsonParser, LlamaParser, QwenParser, ToolParser};
mod common;
use common::{create_test_tools, streaming_helpers::*};
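The chunking helpers are shared via `common::streaming_helpers`; a plausible sketch of their behavior (the real implementations may differ):
fn create_realistic_chunks(input: &str) -> Vec<String> {
    // 2-5 character slices on char boundaries, mimicking LLM token deltas.
    let chars: Vec<char> = input.chars().collect();
    let mut chunks = Vec::new();
    let mut i = 0;
    while i < chars.len() {
        let take = 2 + (i % 4); // deterministic width in 2..=5
        let end = (i + take).min(chars.len());
        chunks.push(chars[i..end].iter().collect());
        i = end;
    }
    chunks
}
fn create_strategic_chunks(input: &str) -> Vec<String> {
    // Break immediately after quotes and colons, the positions most likely
    // to expose premature parsing of incomplete JSON strings.
    let mut chunks = Vec::new();
    let mut current = String::new();
    for c in input.chars() {
        current.push(c);
        if c == '"' || c == ':' {
            chunks.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        chunks.push(current);
    }
    chunks
}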
// =============================================================================
// THE BUG SCENARIO - Most Critical Test
// =============================================================================
#[tokio::test]
async fn test_json_bug_incomplete_tool_name_string() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// This exact sequence triggered the bug:
// Parser receives {"name": " and must NOT parse it as empty name
let chunks = vec![
r#"{"#,
r#"""#,
r#"name"#,
r#"""#,
r#":"#,
r#" "#,
r#"""#, // ← Critical moment: parser has {"name": "
// At this point, partial_json should NOT allow incomplete strings
// when current_tool_name_sent=false
r#"search"#, // Use valid tool name from create_test_tools()
r#"""#,
r#", "#,
r#"""#,
r#"arguments"#,
r#"""#,
r#": {"#,
r#"""#,
r#"query"#,
r#"""#,
r#": "#,
r#"""#,
r#"rust programming"#,
r#"""#,
r#"}}"#,
];
let mut got_tool_name = false;
let mut saw_empty_name = false;
for chunk in chunks.iter() {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = &call.name {
if name.is_empty() {
saw_empty_name = true;
}
if name == "search" {
got_tool_name = true;
}
}
}
}
assert!(
!saw_empty_name,
"Parser should NEVER return empty tool name"
);
assert!(got_tool_name, "Should have parsed tool name correctly");
}
// =============================================================================
// JSON PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test]
async fn test_json_realistic_chunks_simple_tool() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 10, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_strategic_chunks_with_quotes() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#;
let chunks = create_strategic_chunks(input);
// Strategic chunks break after quotes and colons
assert!(chunks.iter().any(|c| c.ends_with('"')));
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_incremental_arguments_streaming() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "test", "limit": 10}}"#;
let chunks = create_realistic_chunks(input);
let mut tool_name_sent = false;
let mut got_arguments = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
tool_name_sent = true;
}
if tool_name_sent && !call.parameters.is_empty() {
got_arguments = true;
}
}
}
assert!(tool_name_sent, "Should have sent tool name");
assert!(got_arguments, "Should have sent arguments");
}
// =============================================================================
// LLAMA PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test]
async fn test_llama_realistic_chunks_with_python_tag() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 15, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "calculate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_llama_python_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
// Python tag itself arrives in small chunks
let chunks = vec![
"<|p", "yth", "on_", "tag", "|>{", r#"""#, "na", r#"me""#, ": ", r#"""#, "sea", "rch",
r#"""#, ", ", r#"""#, "par", "ame", "ter", "s", r#"""#, ": {", r#"""#, "q", r#"""#, ": ",
r#"""#, "tes", "t", r#"""#, "}}",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "search");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
// =============================================================================
// QWEN PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test]
async fn test_qwen_realistic_chunks_with_xml_tags() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let input = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Tokyo\"}}\n</tool_call>";
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 20, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_qwen_xml_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let chunks = vec![
"<to", "ol_", "cal", "l>\n", "{", r#"""#, "na", "me", r#"""#, ": ", r#"""#, "tra", "nsl",
"ate", r#"""#, ", ", r#"""#, "arg", "ume", "nts", r#"""#, ": {", r#"""#, "tex", "t",
r#"""#, ": ", r#"""#, "hel", "lo", r#"""#, "}}\n", "</t", "ool", "_ca", "ll>",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "translate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
// =============================================================================
// EDGE CASES WITH REALISTIC CHUNKS
// =============================================================================
#[tokio::test]
async fn test_json_very_long_url_in_arguments() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// Simulate long URL arriving in many chunks
let long_url = "https://example.com/very/long/path/".to_string() + &"segment/".repeat(50);
let input = format!(
r#"{{"name": "search", "arguments": {{"query": "{}"}}}}"#,
long_url
);
let chunks = create_realistic_chunks(&input);
assert!(chunks.len() > 100, "Long URL should create many chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_unicode_arrives_byte_by_byte() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "Hello 世界 🌍"}}"#;
let chunks = create_realistic_chunks(input);
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed with unicode");
}