Unverified Commit 4fe53e58 authored by Chang Su, committed by GitHub

[router][grpc] Support streaming parsing with Tool Choice in chat completions API (#12677)

parent fb2e816e
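
With this change, the router parses streaming tool calls when `tool_choice` is set: for `"required"` (and allowed-tools in required mode) it incrementally parses the JSON-schema-constrained output, and for a specific function it emits the tool call name/id first, then argument deltas. A minimal client-side sketch of consuming such a stream, assuming an OpenAI-compatible router endpoint; the base URL and model name below are placeholders, not part of this commit:

import json

from openai import OpenAI

# Placeholder endpoint; point this at your router deployment.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="required",  # or {"type": "function", "function": {"name": "get_weather"}}
    stream=True,
)

# Name/id arrive in the first tool-call delta; argument fragments follow.
args_by_index = {}
for chunk in stream:
    delta = chunk.choices[0].delta
    for tc in delta.tool_calls or []:
        if tc.function and tc.function.name:
            print("tool:", tc.function.name)
        if tc.function and tc.function.arguments:
            args_by_index.setdefault(tc.index, []).append(tc.function.arguments)

# The concatenated fragments must form valid JSON (exercised by the tests below).
for parts in args_by_index.values():
    print(json.loads("".join(parts)))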
@@ -274,9 +274,6 @@ class TestToolChoiceLlama32(CustomTestCase):
        self.assertIsNotNone(tool_calls)
        self.assertGreater(len(tool_calls), 0)

-    @unittest.skip(
-        "Skipping required streaming test as it is not supported by the router"
-    )
    def test_tool_choice_required_streaming(self):
        """Test tool_choice='required' in streaming mode"""
        tools = self.get_test_tools()
@@ -325,9 +322,6 @@ class TestToolChoiceLlama32(CustomTestCase):
        for tool_call in tool_calls:
            self.assertEqual(tool_call.function.name, "get_weather")

-    @unittest.skip(
-        "Skipping required streaming test as it is not supported by the router"
-    )
    def test_tool_choice_specific_function_streaming(self):
        """Test tool_choice with specific function in streaming mode"""
        tools = self.get_test_tools()
@@ -363,9 +357,6 @@ class TestToolChoiceLlama32(CustomTestCase):
        self.assertEqual(found_name, "get_weather")

-    @unittest.skip(
-        "Skipping required streaming arguments chunks json test as it is not supported by the router"
-    )
    def test_required_streaming_arguments_chunks_json(self):
        """In streaming required mode, complete tool call arguments should be valid JSON when all chunks are combined"""
        tools = self.get_test_tools()
...
@@ -62,9 +62,15 @@ def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
        except Exception:
            pass

-    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
-        list(ex.map(_direct_load, range(128)))
-    time.sleep(1)
+    # Start background load in a non-blocking way to keep slow worker busy
+    background_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
+    background_futures = []
+    for i in range(32):
+        future = background_executor.submit(_direct_load, i)
+        background_futures.append(future)
+
+    # Wait longer for the load monitor to update (at least 2 monitor intervals)
+    time.sleep(3)

    def call(i):
        r = requests.post(
@@ -85,6 +91,9 @@ def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
    for wid in ex.map(call, range(200)):
        counts[wid] += 1

+    # Clean up background executor
+    background_executor.shutdown(wait=False)
+
    # Expect the slow worker (higher latency/inflight) to receive fewer requests
    fast_worker_id = [i for i in ids if i != slow_id][0]
    assert counts[slow_id] < counts[fast_worker_id], counts
//! Context for /v1/responses endpoint handlers
//!
//! Bundles all dependencies needed by responses handlers to avoid passing
-//! 10+ parameters to every function (fixes clippy::too_many_arguments).
+//! 10+ parameters to every function.

use std::{collections::HashMap, sync::Arc};
...
@@ -210,6 +210,17 @@ impl StreamingProcessor {
            model,
        );

+        // Check if JSON schema constraint was used (specific function or required mode)
+        let used_json_schema = match tool_choice {
+            Some(ToolChoice::Function { .. }) => true,
+            Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
+            Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
+            _ => false,
+        };
+
+        // Check if this is the specific function case (LLM generates parameters only, no name field)
+        let is_specific_function = matches!(tool_choice, Some(ToolChoice::Function { .. }));
+
        let tool_parser_available = tools.is_some()
            && utils::check_tool_parser_availability(
                &self.tool_parser_factory,
@@ -342,10 +353,24 @@ impl StreamingProcessor {
                    if !in_reasoning
                        && tool_choice_enabled
                        && tools.is_some()
-                        && tool_parser_available
+                        && (tool_parser_available || used_json_schema)
                    {
-                        let tool_chunks = self
-                            .process_tool_calls_stream(
+                        let tool_chunks = if is_specific_function {
+                            // Handle specific function case - emit tool call deltas with arguments
+                            Self::process_specific_function_stream(
+                                &delta,
+                                index,
+                                &mut has_tool_calls,
+                                tool_choice,
+                                request_id,
+                                model,
+                                created,
+                                system_fingerprint,
+                                history_tool_calls_count,
+                            )
+                        } else {
+                            // Use incremental parser for regular/required modes
+                            self.process_tool_calls_stream(
                                &delta,
                                index,
                                &mut tool_parsers,
@@ -356,8 +381,10 @@ impl StreamingProcessor {
                                created,
                                system_fingerprint,
                                history_tool_calls_count,
+                                used_json_schema,
                            )
-                            .await;
+                            .await
+                        };

                        for chunk in tool_chunks {
                            Self::format_sse_chunk_into(&mut sse_buffer, &chunk);
@@ -1089,6 +1116,101 @@ impl StreamingProcessor {
        (delta.to_string(), None, false)
    }

+    /// Helper: Process specific function case - emit tool call deltas with arguments
+    #[allow(clippy::too_many_arguments)]
+    fn process_specific_function_stream(
+        delta: &str,
+        index: u32,
+        has_tool_calls: &mut HashMap<u32, bool>,
+        tool_choice: &Option<ToolChoice>,
+        request_id: &str,
+        model: &str,
+        created: u64,
+        system_fingerprint: Option<&str>,
+        history_tool_calls_count: usize,
+    ) -> Vec<ChatCompletionStreamResponse> {
+        let mut chunks = Vec::new();
+
+        if let Some(ToolChoice::Function { function, .. }) = tool_choice {
+            let is_first_call = !has_tool_calls.contains_key(&index);
+
+            if is_first_call {
+                // First chunk: send name and id
+                has_tool_calls.insert(index, true);
+                let tool_call_id = utils::generate_tool_call_id(
+                    model,
+                    &function.name,
+                    0,
+                    history_tool_calls_count,
+                );
+                chunks.push(ChatCompletionStreamResponse {
+                    id: request_id.to_string(),
+                    object: "chat.completion.chunk".to_string(),
+                    created,
+                    model: model.to_string(),
+                    system_fingerprint: system_fingerprint.map(|s| s.to_string()),
+                    choices: vec![ChatStreamChoice {
+                        index,
+                        delta: ChatMessageDelta {
+                            role: Some("assistant".to_string()),
+                            content: None,
+                            tool_calls: Some(vec![ToolCallDelta {
+                                index: 0,
+                                id: Some(tool_call_id),
+                                tool_type: Some("function".to_string()),
+                                function: Some(FunctionCallDelta {
+                                    name: Some(function.name.clone()),
+                                    arguments: None,
+                                }),
+                            }]),
+                            reasoning_content: None,
+                        },
+                        logprobs: None,
+                        finish_reason: None,
+                        matched_stop: None,
+                    }],
+                    usage: None,
+                });
+            }
+
+            // Emit arguments delta
+            if !delta.is_empty() {
+                chunks.push(ChatCompletionStreamResponse {
+                    id: request_id.to_string(),
+                    object: "chat.completion.chunk".to_string(),
+                    created,
+                    model: model.to_string(),
+                    system_fingerprint: system_fingerprint.map(|s| s.to_string()),
+                    choices: vec![ChatStreamChoice {
+                        index,
+                        delta: ChatMessageDelta {
+                            role: Some("assistant".to_string()),
+                            content: None,
+                            tool_calls: Some(vec![ToolCallDelta {
+                                index: 0,
+                                id: None,
+                                tool_type: None,
+                                function: Some(FunctionCallDelta {
+                                    name: None,
+                                    arguments: Some(delta.to_string()),
+                                }),
+                            }]),
+                            reasoning_content: None,
+                        },
+                        logprobs: None,
+                        finish_reason: None,
+                        matched_stop: None,
+                    }],
+                    usage: None,
+                });
+            }
+        }
+
+        chunks
+    }
+
    /// Helper: Process tool calls in streaming mode
    #[allow(clippy::too_many_arguments)]
    async fn process_tool_calls_stream(
@@ -1103,17 +1225,27 @@ impl StreamingProcessor {
        created: u64,
        system_fingerprint: Option<&str>,
        history_tool_calls_count: usize,
+        use_json_parser: bool,
    ) -> Vec<ChatCompletionStreamResponse> {
        let mut chunks = Vec::new();

        // Create fresh parser for this index (not pooled, to avoid state pollution)
        tool_parsers.entry(index).or_insert_with(|| {
-            let parser = utils::create_tool_parser(
-                &self.tool_parser_factory,
-                self.configured_tool_parser.as_ref(),
-                model,
-            )
-            .expect("Parser should be available - checked upfront");
+            let parser = if use_json_parser {
+                utils::create_tool_parser(
+                    &self.tool_parser_factory,
+                    Some(&"json".to_string()),
+                    model,
+                )
+                .expect("JSON parser should be available")
+            } else {
+                utils::create_tool_parser(
+                    &self.tool_parser_factory,
+                    self.configured_tool_parser.as_ref(),
+                    model,
+                )
+                .expect("Parser should be available - checked upfront")
+            };
            Arc::new(tokio::sync::Mutex::new(parser))
        });
...
@@ -21,6 +21,8 @@ use crate::{
/// - Unicode token delimiters
/// - JSON arguments in code blocks
/// - Support for multiple sequential tool calls
+///
+/// Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
pub struct DeepSeekParser {
    /// Regex for extracting complete tool calls
    tool_call_extractor: Regex,
...
@@ -217,8 +217,18 @@ pub fn handle_json_tool_streaming(
        }
    };

-    // Check if JSON is complete
-    let is_complete = end_idx == json_str.len() && serde_json::from_str::<Value>(json_str).is_ok();
+    // Check if JSON is complete - validate only the parsed portion
+    // Ensure end_idx is on a valid UTF-8 character boundary
+    let safe_end_idx = if json_str.is_char_boundary(end_idx) {
+        end_idx
+    } else {
+        // Find the nearest valid character boundary before end_idx
+        (0..end_idx)
+            .rev()
+            .find(|&i| json_str.is_char_boundary(i))
+            .unwrap_or(0)
+    };
+    let is_complete = serde_json::from_str::<Value>(&json_str[..safe_end_idx]).is_ok();

    // Validate tool name if present
    if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
...
@@ -39,6 +39,12 @@ pub struct JsonParser {
    /// Separator between multiple tool calls
    tool_call_separator: &'static str,
+
+    /// Track whether we're parsing array format `[...]` vs single object `{...}`
+    is_array_format: bool,
+
+    /// Track whether we've already stripped the closing ] bracket (for array format)
+    array_closed: bool,
}

impl JsonParser {
@@ -52,6 +58,8 @@ impl JsonParser {
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            tool_call_separator: ",",
+            is_array_format: false,
+            array_closed: false,
        }
    }
@@ -211,14 +219,31 @@ impl ToolParser for JsonParser {
        self.buffer.push_str(chunk);
        let current_text = &self.buffer.clone();

+        // Determine format on first parse (array vs single object)
+        if self.current_tool_id == -1 && self.has_tool_markers(current_text) {
+            self.is_array_format = current_text.trim().starts_with('[');
+        }
+
        // Check if current_text has tool_call
-        let has_tool_start = self.has_tool_markers(current_text)
-            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
+        // Once array is closed, don't treat [ or { as tool markers
+        let has_tool_start = (!self.array_closed && self.has_tool_markers(current_text))
+            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
-            let normal_text = self.buffer.clone();
+            let mut normal_text = self.buffer.clone();
            self.buffer.clear();
+
+            // Strip ] only once (the closing bracket of JSON array format)
+            // Only for array format and only if we haven't already closed it
+            if self.is_array_format
+                && !self.array_closed
+                && self.current_tool_id > 0
+                && normal_text.starts_with("]")
+            {
+                normal_text = normal_text.strip_prefix("]").unwrap().to_string();
+                self.array_closed = true;
+            }
+
            return Ok(StreamingParseResult {
                normal_text,
                calls: vec![],
@@ -233,12 +258,12 @@ impl ToolParser for JsonParser {
        let start_idx = if let Some(bracket_pos) = current_text.find('[') {
            let brace_pos = current_text.find('{');
            match brace_pos {
-                Some(bp) if bp < bracket_pos => bp,
+                Some(bp) => bp,
                _ => bracket_pos,
            }
        } else if let Some(brace_pos) = current_text.find('{') {
            brace_pos
-        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
+        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
@@ -274,5 +299,7 @@ impl ToolParser for JsonParser {
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
+        self.is_array_format = false;
+        self.array_closed = false;
    }
}
@@ -21,6 +21,8 @@ use crate::{
/// - Token-based delimiters
/// - Function calls with explicit indexing
/// - JSON arguments
+///
+/// Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
pub struct KimiK2Parser {
    /// Regex for extracting complete tool calls
    tool_call_extractor: Regex,
...
@@ -181,7 +181,7 @@ impl ToolParser for LlamaParser {
        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_markers(current_text)
-            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
+            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
@@ -205,7 +205,7 @@ impl ToolParser for LlamaParser {
        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
-        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
+        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
...
@@ -17,10 +17,7 @@ use crate::{
/// Handles the Mistral-specific format:
/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
///
-/// Features:
-/// - Bracket counting for proper JSON array extraction
-/// - Support for multiple tool calls in a single array
-/// - String-aware parsing to handle nested brackets in JSON
+/// Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
pub struct MistralParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
@@ -42,7 +39,11 @@ pub struct MistralParser {
    /// Token configuration
    bot_token: &'static str,
+    eot_token: &'static str,
    tool_call_separator: &'static str,
+
+    /// Track whether we've already stripped the closing ] bracket
+    array_closed: bool,
}

impl MistralParser {
@@ -56,7 +57,9 @@ impl MistralParser {
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            bot_token: "[TOOL_CALLS] [",
+            eot_token: "]",
            tool_call_separator: ", ",
+            array_closed: false,
        }
    }
@@ -207,14 +210,27 @@ impl ToolParser for MistralParser {
        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_markers(current_text)
-            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
+            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
-                let normal_text = self.buffer.clone();
+                let mut normal_text = self.buffer.clone();
                self.buffer.clear();
+
+                // Strip ] only once (the closing bracket of [TOOL_CALLS] array)
+                // current_tool_id > 0 means we've parsed at least one tool
+                if !self.array_closed
+                    && self.current_tool_id > 0
+                    && normal_text.starts_with(self.eot_token)
+                {
+                    normal_text = normal_text
+                        .strip_prefix(self.eot_token)
+                        .unwrap()
+                        .to_string();
+                    self.array_closed = true;
+                }
+
                return Ok(StreamingParseResult {
                    normal_text,
                    calls: vec![],
@@ -231,7 +247,7 @@ impl ToolParser for MistralParser {
        // Determine start index for JSON parsing
        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
            pos + self.bot_token.len()
-        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
+        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
@@ -266,5 +282,6 @@ impl ToolParser for MistralParser {
            &mut self.current_tool_name_sent,
            &mut self.streamed_args_for_tool,
        );
+        self.array_closed = false;
    }
}
@@ -9,6 +9,7 @@ use std::sync::OnceLock;
///
/// This format is used by Llama models and uses Python literals
/// rather than JSON for arguments.
+/// Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
use async_trait::async_trait;
use num_traits::ToPrimitive;
use regex::Regex;
...
@@ -19,10 +19,11 @@ use crate::{
/// `<tool_call>\n{"name": "func", "arguments": {...}}\n</tool_call>`
///
/// Features:
-/// - XML-style tags with JSON content
-/// - Support for multiple sequential tool calls
-/// - Newline-aware parsing
-/// - Buffering for partial end tokens
+/// - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+/// - Each individual call is separated by `\n`
+/// - Function Call Object: JSON object with "name" and "arguments" fields
+///
+/// Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
pub struct QwenParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
@@ -49,8 +50,9 @@ pub struct QwenParser {
    normal_text_buffer: String,

    /// Token configuration
-    bot_token: &'static str,
-    eot_token: &'static str,
+    /// Start/end tokens for each individual tool call (not the entire sequence)
+    individual_tool_start_token: &'static str,
+    individual_tool_end_token: &'static str,
    tool_call_separator: &'static str,
}
@@ -70,8 +72,8 @@ impl QwenParser {
            current_tool_name_sent: false,
            streamed_args_for_tool: Vec::new(),
            normal_text_buffer: String::new(),
-            bot_token: "<tool_call>\n",
-            eot_token: "\n</tool_call>",
+            individual_tool_start_token: "<tool_call>\n",
+            individual_tool_end_token: "\n</tool_call>",
            tool_call_separator: "\n",
        }
    }
@@ -157,11 +159,13 @@ impl ToolParser for QwenParser {
        // Check if current_text has tool_call
        let has_tool_start = self.has_tool_markers(current_text)
-            || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
+            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));

        if !has_tool_start {
            // Only clear buffer if we're sure no tool call is starting
-            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
+            if helpers::ends_with_partial_token(&self.buffer, self.individual_tool_start_token)
+                .is_none()
+            {
                let normal_text = self.buffer.clone();
                self.buffer.clear();
@@ -170,7 +174,7 @@ impl ToolParser for QwenParser {
                    calls: vec![],
                });
            } else {
-                // Might be partial bot_token, keep buffering
+                // Might be partial individual_tool_start_token, keep buffering
                return Ok(StreamingParseResult::default());
            }
        }
@@ -179,9 +183,9 @@ impl ToolParser for QwenParser {
        let tool_indices = helpers::get_tool_indices(tools);

        // Determine start index for JSON parsing
-        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
-            pos + self.bot_token.len()
-        } else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
+        let start_idx = if let Some(pos) = current_text.find(self.individual_tool_start_token) {
+            pos + self.individual_tool_start_token.len()
+        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
            self.tool_call_separator.len()
        } else {
            0
@@ -205,7 +209,7 @@ impl ToolParser for QwenParser {
        self.normal_text_buffer.push_str(&result.normal_text);

        // Check if buffer contains complete end token (without leading newline)
-        let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
+        let end_token_without_newline = &self.individual_tool_end_token[1..]; // "</tool_call>"
        if self.normal_text_buffer.contains(end_token_without_newline) {
            // Complete end token found - clean it and return
            let cleaned_text = self
...
@@ -5,6 +5,9 @@
use serde_json::json;
use sglang_router_rs::tool_parser::{JsonParser, ToolParser};

+mod common;
+use common::{create_test_tools, streaming_helpers::*};
+
#[tokio::test]
async fn test_simple_json_tool_call() {
    let parser = JsonParser::new();
@@ -159,3 +162,556 @@ async fn test_json_format_detection() {
    assert!(parser.has_tool_markers(r#"[{"name": "test"}]"#));
    assert!(!parser.has_tool_markers("plain text"));
}
// Streaming tests for JSON array format
#[tokio::test]
async fn test_json_array_streaming_required_mode() {
use sglang_router_rs::protocols::common::Tool;
// Test that simulates the exact streaming pattern from required mode
let mut parser = JsonParser::new();
// Define test tools
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
// Simulate the EXACT chunks from the debug log
let chunks = vec![
"[{",
" \"",
"name",
"\":",
" \"",
"get",
"_weather",
"\",",
" \"",
"parameters",
"\":",
" {",
" \"",
"city",
"\":",
" \"",
"Paris",
"\"",
" }",
" }]",
];
let mut all_results = Vec::new();
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_results.extend(result.calls);
all_normal_text.push_str(&result.normal_text);
}
// We should have gotten tool call chunks
assert!(
!all_results.is_empty(),
"Should have emitted tool call chunks"
);
// Should not have emitted any normal text (including the closing ])
assert_eq!(
all_normal_text, "",
"Should not emit normal text for JSON array format"
);
// Check that we got the function name
let has_name = all_results
.iter()
.any(|item| item.name.as_ref().is_some_and(|n| n == "get_weather"));
assert!(has_name, "Should have emitted function name");
// Check that we got the parameters
let has_params = all_results.iter().any(|item| !item.parameters.is_empty());
assert!(has_params, "Should have emitted parameters");
}
#[tokio::test]
async fn test_json_array_multiple_tools_streaming() {
use sglang_router_rs::protocols::common::Tool;
// Test with multiple tools in array
let mut parser = JsonParser::new();
let tools = vec![
Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
},
Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_news".to_string(),
description: Some("Get news".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
},
];
// Split into smaller, more realistic chunks
let chunks = vec![
"[{",
"\"name\":",
"\"get_weather\"",
",\"parameters\":",
"{\"city\":",
"\"SF\"}",
"}",
",",
"{\"name\":",
"\"get_news\"",
",\"parameters\":",
"{\"topic\":",
"\"tech\"}",
"}]",
];
let mut all_results = Vec::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_results.extend(result.calls);
}
// Should have gotten tool calls for both functions
let has_weather = all_results
.iter()
.any(|item| item.name.as_ref().is_some_and(|n| n == "get_weather"));
let has_news = all_results
.iter()
.any(|item| item.name.as_ref().is_some_and(|n| n == "get_news"));
assert!(has_weather, "Should have get_weather tool call");
assert!(has_news, "Should have get_news tool call");
}
#[tokio::test]
async fn test_json_array_closing_bracket_separate_chunk() {
use sglang_router_rs::protocols::common::Tool;
// Test case where the closing ] comes as a separate chunk
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}];
// Closing ] as separate chunk, followed by normal text
let chunks = vec![
"[{",
"\"",
"name",
"\":",
"\"",
"get",
"_weather",
"\",",
"\"",
"parameters",
"\":",
"{",
"\"",
"city",
"\":",
"\"",
"Paris",
"\"",
"}",
"}",
"]",
" Here's",
" the",
" weather",
" info",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should emit only the trailing text chunks as normal text, NOT the ]
assert_eq!(
all_normal_text, " Here's the weather info",
"Should emit only normal text without ], got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_json_single_object_with_trailing_text() {
use sglang_router_rs::protocols::common::Tool;
// Test single object format (no array) with trailing text
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
let chunks = vec![
"{",
"\"",
"name",
"\":",
"\"",
"get_weather",
"\",",
"\"",
"parameters",
"\":",
"{",
"\"city",
"\":",
"\"Paris",
"\"}",
"}",
" Here's",
" the",
" weather",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should emit the trailing text as normal_text (no ] to strip for single object)
assert_eq!(
all_normal_text, " Here's the weather",
"Should emit normal text for single object format, got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_json_single_object_with_bracket_in_text() {
use sglang_router_rs::protocols::common::Tool;
// Test that ] in normal text is NOT stripped for single object format
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
let chunks = vec![
"{",
"\"name",
"\":",
"\"get_weather",
"\",",
"\"parameters",
"\":",
"{",
"\"city",
"\":",
"\"Paris",
"\"}",
"}",
"]",
" Here's",
" the",
" weather",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// For single object format, ] should NOT be stripped (it's part of normal text)
assert_eq!(
all_normal_text, "] Here's the weather",
"Should preserve ] in normal text for single object format, got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_json_array_bracket_in_text_after_tools() {
use sglang_router_rs::protocols::common::Tool;
// Test that ] in normal text AFTER array tools is preserved
let mut parser = JsonParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: serde_json::json!({}),
strict: None,
},
}];
let chunks = vec![
"[",
"{",
"\"name",
"\":",
"\"get_weather",
"\",",
"\"parameters",
"\":",
"{",
"\"city",
"\":",
"\"Paris",
"\"}",
"}",
"]",
" Array",
" notation:",
" arr",
"[",
"0",
"]",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should preserve ] in normal text after array tools complete
assert_eq!(
all_normal_text, " Array notation: arr[0]",
"Should preserve ] in normal text after array tools, got: '{}'",
all_normal_text
);
}
// =============================================================================
// REALISTIC STREAMING TESTS
// =============================================================================
#[tokio::test]
async fn test_json_bug_incomplete_tool_name_string() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// This exact sequence triggered the bug:
// Parser receives {"name": " and must NOT parse it as empty name
let chunks = vec![
r#"{"#,
r#"""#,
r#"name"#,
r#"""#,
r#":"#,
r#" "#,
r#"""#, // ← Critical moment: parser has {"name": "
// At this point, partial_json should NOT allow incomplete strings
// when current_tool_name_sent=false
r#"search"#, // Use valid tool name from create_test_tools()
r#"""#,
r#", "#,
r#"""#,
r#"arguments"#,
r#"""#,
r#": {"#,
r#"""#,
r#"query"#,
r#"""#,
r#": "#,
r#"""#,
r#"rust programming"#,
r#"""#,
r#"}}"#,
];
let mut got_tool_name = false;
let mut saw_empty_name = false;
for chunk in chunks.iter() {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = &call.name {
if name.is_empty() {
saw_empty_name = true;
}
if name == "search" {
got_tool_name = true;
}
}
}
}
assert!(
!saw_empty_name,
"Parser should NEVER return empty tool name"
);
assert!(got_tool_name, "Should have parsed tool name correctly");
}
#[tokio::test]
async fn test_json_realistic_chunks_simple_tool() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 10, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_strategic_chunks_with_quotes() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#;
let chunks = create_strategic_chunks(input);
// Strategic chunks break after quotes and colons
assert!(chunks.iter().any(|c| c.ends_with('"')));
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_incremental_arguments_streaming() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "test", "limit": 10}}"#;
let chunks = create_realistic_chunks(input);
let mut tool_name_sent = false;
let mut got_arguments = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
tool_name_sent = true;
}
if tool_name_sent && !call.parameters.is_empty() {
got_arguments = true;
}
}
}
assert!(tool_name_sent, "Should have sent tool name");
assert!(got_arguments, "Should have sent arguments");
}
#[tokio::test]
async fn test_json_very_long_url_in_arguments() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// Simulate long URL arriving in many chunks
let long_url = "https://example.com/very/long/path/".to_string() + &"segment/".repeat(50);
let input = format!(
r#"{{"name": "search", "arguments": {{"query": "{}"}}}}"#,
long_url
);
let chunks = create_realistic_chunks(&input);
assert!(chunks.len() > 100, "Long URL should create many chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_unicode() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "Hello 世界 🌍"}}"#;
let chunks = create_realistic_chunks(input);
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed with unicode");
}
@@ -5,7 +5,7 @@
use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};

mod common;
-use common::create_test_tools;
+use common::{create_test_tools, streaming_helpers::*};

#[tokio::test]
async fn test_llama_python_tag_format() {
@@ -397,3 +397,59 @@ async fn test_llama_streaming_multiple_tools_chunked() {
        }
    }
}
// =============================================================================
// REALISTIC STREAMING TESTS
// =============================================================================
#[tokio::test]
async fn test_llama_realistic_chunks_with_python_tag() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 15, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "calculate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_llama_python_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
// Python tag itself arrives in small chunks
let chunks = vec![
"<|p", "yth", "on_", "tag", "|>{", r#"""#, "na", r#"me""#, ": ", r#"""#, "sea", "rch",
r#"""#, ", ", r#"""#, "par", "ame", "ter", "s", r#"""#, ": {", r#"""#, "q", r#"""#, ": ",
r#"""#, "tes", "t", r#"""#, "}}",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "search");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
@@ -155,3 +155,120 @@ Let me execute these searches for you."#;
    assert_eq!(tools[0].function.name, "web_search");
    assert_eq!(tools[1].function.name, "get_weather");
}
#[tokio::test]
async fn test_mistral_streaming_closing_bracket() {
use sglang_router_rs::protocols::common::Tool;
// Test that closing ] is stripped for Mistral array format
let mut parser = MistralParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}];
let chunks = vec![
"[TOOL_CALLS] ",
"[{",
"\"",
"name",
"\":",
"\"",
"get",
"_weather",
"\",",
"\"",
"arguments",
"\":",
"{",
"\"",
"city",
"\":",
"\"",
"Paris",
"\"",
"}",
"}",
"]",
" Here's",
" the weather",
" info",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should emit only the trailing text chunks as normal text, NOT the ]
assert_eq!(
all_normal_text, " Here's the weather info",
"Should not emit ] for Mistral array format, got: '{}'",
all_normal_text
);
}
#[tokio::test]
async fn test_mistral_streaming_bracket_in_text_after_tools() {
use sglang_router_rs::protocols::common::Tool;
// Test that ] in normal text AFTER tool calls is preserved
let mut parser = MistralParser::new();
let tools = vec![Tool {
tool_type: "function".to_string(),
function: sglang_router_rs::protocols::common::Function {
name: "get_weather".to_string(),
description: Some("Get weather".to_string()),
parameters: json!({}),
strict: None,
},
}];
let chunks = vec![
"[TOOL_CALLS] ",
"[",
"{",
"\"name",
"\":",
"\"get_weather",
"\",",
"\"arguments",
"\":",
"{\"",
"city",
"\":",
"\"Paris",
"\"}",
"}",
"]",
" Array",
" notation:",
" arr",
"[",
"0",
"]",
];
let mut all_normal_text = String::new();
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
all_normal_text.push_str(&result.normal_text);
}
// Should preserve ] in normal text after tools complete
assert_eq!(
all_normal_text, " Array notation: arr[0]",
"Should preserve ] in normal text after tools, got: '{}'",
all_normal_text
);
}
@@ -6,7 +6,7 @@ use serde_json::json;
use sglang_router_rs::tool_parser::{QwenParser, ToolParser};

mod common;
-use common::create_test_tools;
+use common::{create_test_tools, streaming_helpers::*};

#[tokio::test]
async fn test_qwen_single_tool() {
@@ -250,3 +250,58 @@ async fn test_buffer_efficiency_with_multiple_tools() {
        }
    }
}
// =============================================================================
// REALISTIC STREAMING TESTS
// =============================================================================
#[tokio::test]
async fn test_qwen_realistic_chunks_with_xml_tags() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let input = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Tokyo\"}}\n</tool_call>";
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 20, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_qwen_xml_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let chunks = vec![
"<to", "ol_", "cal", "l>\n", "{", r#"""#, "na", "me", r#"""#, ": ", r#"""#, "tra", "nsl",
"ate", r#"""#, ", ", r#"""#, "arg", "ume", "nts", r#"""#, ": {", r#"""#, "tex", "t",
r#"""#, ": ", r#"""#, "hel", "lo", r#"""#, "}}\n", "</t", "ool", "_ca", "ll>",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "translate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
//! Realistic Streaming Parser Tests
//!
//! Tests incremental parsing with realistic char-level chunks (2-5 chars)
//! that simulate how LLM tokens actually arrive.
//!
//! These tests are designed to catch bugs like `{"name": "` being parsed
//! as an empty tool name.
use sglang_router_rs::tool_parser::{JsonParser, LlamaParser, QwenParser, ToolParser};
mod common;
use common::{create_test_tools, streaming_helpers::*};
// =============================================================================
// THE BUG SCENARIO - Most Critical Test
// =============================================================================
#[tokio::test]
async fn test_json_bug_incomplete_tool_name_string() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// This exact sequence triggered the bug:
// Parser receives {"name": " and must NOT parse it as empty name
let chunks = vec![
r#"{"#,
r#"""#,
r#"name"#,
r#"""#,
r#":"#,
r#" "#,
r#"""#, // ← Critical moment: parser has {"name": "
// At this point, partial_json should NOT allow incomplete strings
// when current_tool_name_sent=false
r#"search"#, // Use valid tool name from create_test_tools()
r#"""#,
r#", "#,
r#"""#,
r#"arguments"#,
r#"""#,
r#": {"#,
r#"""#,
r#"query"#,
r#"""#,
r#": "#,
r#"""#,
r#"rust programming"#,
r#"""#,
r#"}}"#,
];
let mut got_tool_name = false;
let mut saw_empty_name = false;
for chunk in chunks.iter() {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = &call.name {
if name.is_empty() {
saw_empty_name = true;
}
if name == "search" {
got_tool_name = true;
}
}
}
}
assert!(
!saw_empty_name,
"Parser should NEVER return empty tool name"
);
assert!(got_tool_name, "Should have parsed tool name correctly");
}
// =============================================================================
// JSON PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test]
async fn test_json_realistic_chunks_simple_tool() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 10, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_strategic_chunks_with_quotes() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#;
let chunks = create_strategic_chunks(input);
// Strategic chunks break after quotes and colons
assert!(chunks.iter().any(|c| c.ends_with('"')));
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_incremental_arguments_streaming() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "test", "limit": 10}}"#;
let chunks = create_realistic_chunks(input);
let mut tool_name_sent = false;
let mut got_arguments = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
tool_name_sent = true;
}
if tool_name_sent && !call.parameters.is_empty() {
got_arguments = true;
}
}
}
assert!(tool_name_sent, "Should have sent tool name");
assert!(got_arguments, "Should have sent arguments");
}
// =============================================================================
// LLAMA PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test]
async fn test_llama_realistic_chunks_with_python_tag() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 15, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "calculate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_llama_python_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = LlamaParser::new();
// Python tag itself arrives in small chunks
let chunks = vec![
"<|p", "yth", "on_", "tag", "|>{", r#"""#, "na", r#"me""#, ": ", r#"""#, "sea", "rch",
r#"""#, ", ", r#"""#, "par", "ame", "ter", "s", r#"""#, ": {", r#"""#, "q", r#"""#, ": ",
r#"""#, "tes", "t", r#"""#, "}}",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "search");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
// =============================================================================
// QWEN PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test]
async fn test_qwen_realistic_chunks_with_xml_tags() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let input = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Tokyo\"}}\n</tool_call>";
let chunks = create_realistic_chunks(input);
assert!(chunks.len() > 20, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_qwen_xml_tag_arrives_in_parts() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
let chunks = vec![
"<to", "ol_", "cal", "l>\n", "{", r#"""#, "na", "me", r#"""#, ": ", r#"""#, "tra", "nsl",
"ate", r#"""#, ", ", r#"""#, "arg", "ume", "nts", r#"""#, ": {", r#"""#, "tex", "t",
r#"""#, ": ", r#"""#, "hel", "lo", r#"""#, "}}\n", "</t", "ool", "_ca", "ll>",
];
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
for call in result.calls {
if let Some(name) = call.name {
assert_eq!(name, "translate");
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
// =============================================================================
// EDGE CASES WITH REALISTIC CHUNKS
// =============================================================================
#[tokio::test]
async fn test_json_very_long_url_in_arguments() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
// Simulate long URL arriving in many chunks
let long_url = "https://example.com/very/long/path/".to_string() + &"segment/".repeat(50);
let input = format!(
r#"{{"name": "search", "arguments": {{"query": "{}"}}}}"#,
long_url
);
let chunks = create_realistic_chunks(&input);
assert!(chunks.len() > 100, "Long URL should create many chunks");
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_json_unicode_arrives_byte_by_byte() {
let tools = create_test_tools();
let mut parser = JsonParser::new();
let input = r#"{"name": "search", "arguments": {"query": "Hello 世界 🌍"}}"#;
let chunks = create_realistic_chunks(input);
let mut got_tool_name = false;
for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
}
}
assert!(got_tool_name, "Should have parsed with unicode");
}