Unverified Commit ab926dd6 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Fix streaming bugs: empty tool names, state pollution, and panics (#11373)

parent a4b424c6
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use serde_json::json; use serde_json::json;
use sglang_router_rs::protocols::spec::{Function, Tool}; use sglang_router_rs::protocols::spec::{Function, Tool};
use sglang_router_rs::tool_parser::{JsonParser, ToolParser, ToolParserFactory}; use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
......
...@@ -82,9 +82,15 @@ impl ParserRegistry { ...@@ -82,9 +82,15 @@ impl ParserRegistry {
} }
} }
/// Get a parser by exact name (creates new instance, not pooled). /// Check if a parser with the given name is registered.
/// Use this for compatibility or when you need a fresh instance. pub fn has_parser(&self, name: &str) -> bool {
pub fn get_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> { let creators = self.creators.read().unwrap();
creators.contains_key(name)
}
/// Create a fresh parser instance by exact name (not pooled).
/// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
pub fn create_parser(&self, name: &str) -> Option<Box<dyn ReasoningParser>> {
let creators = self.creators.read().unwrap(); let creators = self.creators.read().unwrap();
creators.get(name).map(|creator| creator()) creators.get(name).map(|creator| creator())
} }
...@@ -102,14 +108,30 @@ impl ParserRegistry { ...@@ -102,14 +108,30 @@ impl ParserRegistry {
None None
} }
/// Find a parser for a given model ID by pattern matching (creates new instance). /// Check if a parser can be created for a specific model without actually creating it.
pub fn find_parser_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> { /// Returns true if a parser is available (registered) for this model.
pub fn has_parser_for_model(&self, model_id: &str) -> bool {
    // Case-insensitive substring match, same rule used by create_for_model.
    let needle = model_id.to_lowercase();
    let patterns = self.patterns.read().unwrap();
    patterns
        .iter()
        .find(|(pattern, _)| needle.contains(&pattern.to_lowercase()))
        // First pattern hit decides; availability means a creator is registered.
        .is_some_and(|(_, parser_name)| {
            self.creators.read().unwrap().contains_key(parser_name)
        })
}
/// Create a fresh parser instance for a given model ID by pattern matching (not pooled).
/// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
pub fn create_for_model(&self, model_id: &str) -> Option<Box<dyn ReasoningParser>> {
let patterns = self.patterns.read().unwrap(); let patterns = self.patterns.read().unwrap();
let model_lower = model_id.to_lowercase(); let model_lower = model_id.to_lowercase();
for (pattern, parser_name) in patterns.iter() { for (pattern, parser_name) in patterns.iter() {
if model_lower.contains(&pattern.to_lowercase()) { if model_lower.contains(&pattern.to_lowercase()) {
return self.get_parser(parser_name); return self.create_parser(parser_name);
} }
} }
None None
...@@ -131,11 +153,11 @@ impl Default for ParserRegistry { ...@@ -131,11 +153,11 @@ impl Default for ParserRegistry {
/// Factory for creating reasoning parsers based on model type. /// Factory for creating reasoning parsers based on model type.
#[derive(Clone)] #[derive(Clone)]
pub struct ReasoningParserFactory { pub struct ParserFactory {
registry: ParserRegistry, registry: ParserRegistry,
} }
impl ReasoningParserFactory { impl ParserFactory {
/// Create a new factory with default parsers registered. /// Create a new factory with default parsers registered.
pub fn new() -> Self { pub fn new() -> Self {
let registry = ParserRegistry::new(); let registry = ParserRegistry::new();
...@@ -211,7 +233,7 @@ impl ReasoningParserFactory { ...@@ -211,7 +233,7 @@ impl ReasoningParserFactory {
/// Use this when you need an isolated parser instance. /// Use this when you need an isolated parser instance.
pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> { pub fn create(&self, model_id: &str) -> Result<Box<dyn ReasoningParser>, ParseError> {
// First try to find by pattern // First try to find by pattern
if let Some(parser) = self.registry.find_parser_for_model(model_id) { if let Some(parser) = self.registry.create_for_model(model_id) {
return Ok(parser); return Ok(parser);
} }
...@@ -240,7 +262,7 @@ impl ReasoningParserFactory { ...@@ -240,7 +262,7 @@ impl ReasoningParserFactory {
} }
} }
impl Default for ReasoningParserFactory { impl Default for ParserFactory {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
...@@ -252,35 +274,35 @@ mod tests { ...@@ -252,35 +274,35 @@ mod tests {
#[test] #[test]
fn test_factory_creates_deepseek_r1() { fn test_factory_creates_deepseek_r1() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let parser = factory.create("deepseek-r1-distill").unwrap(); let parser = factory.create("deepseek-r1-distill").unwrap();
assert_eq!(parser.model_type(), "deepseek_r1"); assert_eq!(parser.model_type(), "deepseek_r1");
} }
#[test] #[test]
fn test_factory_creates_qwen3() { fn test_factory_creates_qwen3() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let parser = factory.create("qwen3-7b").unwrap(); let parser = factory.create("qwen3-7b").unwrap();
assert_eq!(parser.model_type(), "qwen3"); assert_eq!(parser.model_type(), "qwen3");
} }
#[test] #[test]
fn test_factory_creates_kimi() { fn test_factory_creates_kimi() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let parser = factory.create("kimi-chat").unwrap(); let parser = factory.create("kimi-chat").unwrap();
assert_eq!(parser.model_type(), "kimi"); assert_eq!(parser.model_type(), "kimi");
} }
#[test] #[test]
fn test_factory_fallback_to_passthrough() { fn test_factory_fallback_to_passthrough() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let parser = factory.create("unknown-model").unwrap(); let parser = factory.create("unknown-model").unwrap();
assert_eq!(parser.model_type(), "passthrough"); assert_eq!(parser.model_type(), "passthrough");
} }
#[test] #[test]
fn test_case_insensitive_matching() { fn test_case_insensitive_matching() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let parser1 = factory.create("DeepSeek-R1").unwrap(); let parser1 = factory.create("DeepSeek-R1").unwrap();
let parser2 = factory.create("QWEN3").unwrap(); let parser2 = factory.create("QWEN3").unwrap();
let parser3 = factory.create("Kimi").unwrap(); let parser3 = factory.create("Kimi").unwrap();
...@@ -292,21 +314,21 @@ mod tests { ...@@ -292,21 +314,21 @@ mod tests {
#[test] #[test]
fn test_step3_model() { fn test_step3_model() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let step3 = factory.create("step3-model").unwrap(); let step3 = factory.create("step3-model").unwrap();
assert_eq!(step3.model_type(), "step3"); assert_eq!(step3.model_type(), "step3");
} }
#[test] #[test]
fn test_glm45_model() { fn test_glm45_model() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let glm45 = factory.create("glm45-v2").unwrap(); let glm45 = factory.create("glm45-v2").unwrap();
assert_eq!(glm45.model_type(), "glm45"); assert_eq!(glm45.model_type(), "glm45");
} }
#[tokio::test] #[tokio::test]
async fn test_pooled_parser_reuse() { async fn test_pooled_parser_reuse() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
// Get the same parser twice - should be the same instance // Get the same parser twice - should be the same instance
let parser1 = factory.get_pooled("deepseek-r1"); let parser1 = factory.get_pooled("deepseek-r1");
...@@ -322,7 +344,7 @@ mod tests { ...@@ -322,7 +344,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_pooled_parser_concurrent_access() { async fn test_pooled_parser_concurrent_access() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let parser = factory.get_pooled("deepseek-r1"); let parser = factory.get_pooled("deepseek-r1");
// Spawn multiple async tasks that use the same parser // Spawn multiple async tasks that use the same parser
...@@ -348,7 +370,7 @@ mod tests { ...@@ -348,7 +370,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_pool_clearing() { async fn test_pool_clearing() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
// Get a pooled parser // Get a pooled parser
let parser1 = factory.get_pooled("deepseek-r1"); let parser1 = factory.get_pooled("deepseek-r1");
...@@ -365,7 +387,7 @@ mod tests { ...@@ -365,7 +387,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_passthrough_parser_pooling() { async fn test_passthrough_parser_pooling() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
// Unknown models should get passthrough parser // Unknown models should get passthrough parser
let parser1 = factory.get_pooled("unknown-model-1"); let parser1 = factory.get_pooled("unknown-model-1");
...@@ -383,7 +405,7 @@ mod tests { ...@@ -383,7 +405,7 @@ mod tests {
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant; use std::time::Instant;
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let num_tasks = 100; let num_tasks = 100;
let requests_per_task = 50; let requests_per_task = 50;
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"]; let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
...@@ -512,7 +534,7 @@ mod tests { ...@@ -512,7 +534,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_concurrent_pool_modifications() { async fn test_concurrent_pool_modifications() {
let factory = ReasoningParserFactory::new(); let factory = ParserFactory::new();
let mut handles = vec![]; let mut handles = vec![];
// Task 1: Continuously get parsers // Task 1: Continuously get parsers
......
...@@ -2,7 +2,7 @@ pub mod factory; ...@@ -2,7 +2,7 @@ pub mod factory;
pub mod parsers; pub mod parsers;
pub mod traits; pub mod traits;
pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory}; pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use parsers::{ pub use parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser, BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser, QwenThinkingParser, Step3Parser,
......
...@@ -15,10 +15,10 @@ use crate::grpc_client::{proto, SglangSchedulerClient}; ...@@ -15,10 +15,10 @@ use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse, ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse,
}; };
use crate::reasoning_parser::ReasoningParserFactory; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::tokenizer::stop::StopSequenceDecoder; use crate::tokenizer::stop::StopSequenceDecoder;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory; use crate::tool_parser::ParserFactory as ToolParserFactory;
// ============================================================================ // ============================================================================
// Core Context Types // Core Context Types
......
...@@ -7,11 +7,11 @@ use crate::protocols::spec::{ ...@@ -7,11 +7,11 @@ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest, ResponsesGetParams, ResponsesRequest,
}; };
use crate::reasoning_parser::ReasoningParserFactory; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use crate::server::AppContext; use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory; use crate::tool_parser::ParserFactory as ToolParserFactory;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
......
...@@ -13,10 +13,10 @@ use crate::protocols::spec::{ ...@@ -13,10 +13,10 @@ use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall, ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall,
ToolChoice, ToolChoiceValue, ToolChoice, ToolChoiceValue,
}; };
use crate::reasoning_parser::ReasoningParserFactory; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory; use crate::tool_parser::ParserFactory as ToolParserFactory;
use super::utils; use super::utils;
......
...@@ -18,11 +18,11 @@ use crate::protocols::spec::{ ...@@ -18,11 +18,11 @@ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest, ResponsesGetParams, ResponsesRequest,
}; };
use crate::reasoning_parser::ReasoningParserFactory; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use crate::server::AppContext; use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory; use crate::tool_parser::ParserFactory as ToolParserFactory;
/// gRPC router implementation for SGLang /// gRPC router implementation for SGLang
#[derive(Clone)] #[derive(Clone)]
......
...@@ -34,8 +34,8 @@ use tokio::sync::mpsc; ...@@ -34,8 +34,8 @@ use tokio::sync::mpsc;
#[derive(Clone)] #[derive(Clone)]
pub struct StreamingProcessor { pub struct StreamingProcessor {
tokenizer: Arc<dyn Tokenizer>, tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: crate::tool_parser::ToolParserFactory, tool_parser_factory: crate::tool_parser::ParserFactory,
reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory, reasoning_parser_factory: crate::reasoning_parser::ParserFactory,
configured_tool_parser: Option<String>, configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>, configured_reasoning_parser: Option<String>,
} }
...@@ -43,8 +43,8 @@ pub struct StreamingProcessor { ...@@ -43,8 +43,8 @@ pub struct StreamingProcessor {
impl StreamingProcessor { impl StreamingProcessor {
pub fn new( pub fn new(
tokenizer: Arc<dyn Tokenizer>, tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: crate::tool_parser::ToolParserFactory, tool_parser_factory: crate::tool_parser::ParserFactory,
reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory, reasoning_parser_factory: crate::reasoning_parser::ParserFactory,
configured_tool_parser: Option<String>, configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>, configured_reasoning_parser: Option<String>,
) -> Self { ) -> Self {
...@@ -195,6 +195,47 @@ impl StreamingProcessor { ...@@ -195,6 +195,47 @@ impl StreamingProcessor {
let created = dispatch.created; let created = dispatch.created;
let system_fingerprint = dispatch.weight_version.as_deref(); let system_fingerprint = dispatch.weight_version.as_deref();
// Check parser availability once upfront (log warning only once per request)
let reasoning_parser_available = if separate_reasoning {
if let Some(parser_name) = self.configured_reasoning_parser.as_ref() {
self.reasoning_parser_factory
.registry()
.has_parser(parser_name)
} else {
self.reasoning_parser_factory
.registry()
.has_parser_for_model(model)
}
} else {
false
};
let tool_parser_available = if tools.is_some() {
if let Some(parser_name) = self.configured_tool_parser.as_ref() {
self.tool_parser_factory.registry().has_parser(parser_name)
} else {
self.tool_parser_factory
.registry()
.has_parser_for_model(model)
}
} else {
false
};
if separate_reasoning && !reasoning_parser_available {
warn!(
"No reasoning parser found for model '{}', skipping reasoning parsing",
model
);
}
if tools.is_some() && !tool_parser_available {
warn!(
"No tool parser found for model '{}', skipping tool call parsing",
model
);
}
// Phase 2: Main streaming loop // Phase 2: Main streaming loop
while let Some(response) = grpc_stream.next().await { while let Some(response) = grpc_stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
...@@ -276,7 +317,7 @@ impl StreamingProcessor { ...@@ -276,7 +317,7 @@ impl StreamingProcessor {
stream_buffer.push_str(&delta); stream_buffer.push_str(&delta);
// Reasoning content handling // Reasoning content handling
let in_reasoning = if separate_reasoning { let in_reasoning = if separate_reasoning && reasoning_parser_available {
let (normal_text, reasoning_chunk, in_reasoning) = self let (normal_text, reasoning_chunk, in_reasoning) = self
.process_reasoning_stream( .process_reasoning_stream(
&delta, &delta,
...@@ -303,8 +344,12 @@ impl StreamingProcessor { ...@@ -303,8 +344,12 @@ impl StreamingProcessor {
let tool_choice_enabled = let tool_choice_enabled =
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None))); !matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
if !in_reasoning && tool_choice_enabled && tools.is_some() { if !in_reasoning
let (should_skip, tool_chunks) = self && tool_choice_enabled
&& tools.is_some()
&& tool_parser_available
{
let tool_chunks = self
.process_tool_calls_stream( .process_tool_calls_stream(
&delta, &delta,
index, index,
...@@ -325,10 +370,9 @@ impl StreamingProcessor { ...@@ -325,10 +370,9 @@ impl StreamingProcessor {
.map_err(|_| "Failed to send tool call chunk".to_string())?; .map_err(|_| "Failed to send tool call chunk".to_string())?;
} }
// Continue to process the next chunk as we have tool chunks // Always skip regular content when tool parsing is active
if should_skip { // Parser either emitted chunks or buffered content
continue; continue;
}
} }
// Regular content emission // Regular content emission
...@@ -963,13 +1007,15 @@ impl StreamingProcessor { ...@@ -963,13 +1007,15 @@ impl StreamingProcessor {
created: u64, created: u64,
system_fingerprint: Option<&str>, system_fingerprint: Option<&str>,
) -> (String, Option<ChatCompletionStreamResponse>, bool) { ) -> (String, Option<ChatCompletionStreamResponse>, bool) {
// Get or create parser for this index // Create fresh parser for this index (not pooled, to avoid state pollution)
reasoning_parsers.entry(index).or_insert_with(|| { reasoning_parsers.entry(index).or_insert_with(|| {
utils::get_reasoning_parser( let parser = utils::create_reasoning_parser(
&self.reasoning_parser_factory, &self.reasoning_parser_factory,
self.configured_reasoning_parser.as_ref(), self.configured_reasoning_parser.as_ref(),
model, model,
) )
.expect("Parser should be available - checked upfront");
Arc::new(tokio::sync::Mutex::new(parser))
}); });
if let Some(pooled_parser) = reasoning_parsers.get(&index) { if let Some(pooled_parser) = reasoning_parsers.get(&index) {
...@@ -1034,20 +1080,23 @@ impl StreamingProcessor { ...@@ -1034,20 +1080,23 @@ impl StreamingProcessor {
created: u64, created: u64,
system_fingerprint: Option<&str>, system_fingerprint: Option<&str>,
history_tool_calls_count: usize, history_tool_calls_count: usize,
) -> (bool, Vec<ChatCompletionStreamResponse>) { ) -> Vec<ChatCompletionStreamResponse> {
let mut chunks = Vec::new(); let mut chunks = Vec::new();
// Get or create parser for this index // Create fresh parser for this index (not pooled, to avoid state pollution)
tool_parsers.entry(index).or_insert_with(|| { tool_parsers.entry(index).or_insert_with(|| {
utils::get_tool_parser( let parser = utils::create_tool_parser(
&self.tool_parser_factory, &self.tool_parser_factory,
self.configured_tool_parser.as_ref(), self.configured_tool_parser.as_ref(),
model, model,
) )
.expect("Parser should be available - checked upfront");
Arc::new(tokio::sync::Mutex::new(parser))
}); });
if let Some(pooled_parser) = tool_parsers.get(&index) { if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await; let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await { match parser.parse_incremental(delta, tools).await {
Ok(crate::tool_parser::StreamingParseResult { normal_text, calls }) => { Ok(crate::tool_parser::StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present // Emit normal text if present
...@@ -1129,8 +1178,7 @@ impl StreamingProcessor { ...@@ -1129,8 +1178,7 @@ impl StreamingProcessor {
}); });
} }
// If we emitted chunks, skip regular content return chunks;
return (!chunks.is_empty(), chunks);
} }
Err(e) => { Err(e) => {
error!("Tool call parsing error: {}", e); error!("Tool call parsing error: {}", e);
...@@ -1138,7 +1186,7 @@ impl StreamingProcessor { ...@@ -1138,7 +1186,7 @@ impl StreamingProcessor {
} }
} }
(false, chunks) chunks
} }
/// Format a response as SSE chunk into a reusable buffer /// Format a response as SSE chunk into a reusable buffer
......
...@@ -677,13 +677,12 @@ pub fn generate_tool_call_id( ...@@ -677,13 +677,12 @@ pub fn generate_tool_call_id(
/// ///
/// If a parser name is explicitly configured, use that parser. /// If a parser name is explicitly configured, use that parser.
/// Otherwise, auto-detect based on the model name. /// Otherwise, auto-detect based on the model name.
/// Get a pooled reasoning parser (for non-streaming where state doesn't matter)
pub fn get_reasoning_parser( pub fn get_reasoning_parser(
reasoning_parser_factory: &crate::reasoning_parser::ReasoningParserFactory, reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
configured_parser: Option<&String>, configured_parser: Option<&String>,
model: &str, model: &str,
) -> crate::reasoning_parser::PooledParser { ) -> crate::reasoning_parser::PooledParser {
use tracing::warn;
if let Some(parser_name) = configured_parser { if let Some(parser_name) = configured_parser {
// Use configured parser if specified // Use configured parser if specified
reasoning_parser_factory reasoning_parser_factory
...@@ -702,17 +701,40 @@ pub fn get_reasoning_parser( ...@@ -702,17 +701,40 @@ pub fn get_reasoning_parser(
} }
} }
/// Create a fresh reasoning parser instance (for streaming where state isolation is needed)
pub fn create_reasoning_parser(
reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
configured_parser: Option<&String>,
model: &str,
) -> Option<Box<dyn crate::reasoning_parser::ReasoningParser>> {
if let Some(parser_name) = configured_parser {
// Use configured parser if specified
reasoning_parser_factory
.registry()
.create_parser(parser_name)
.or_else(|| {
warn!(
"Configured reasoning parser '{}' not found, falling back to model-based selection",
parser_name
);
reasoning_parser_factory.registry().create_for_model(model)
})
} else {
// Auto-detect based on model
reasoning_parser_factory.registry().create_for_model(model)
}
}
/// Get the appropriate tool parser for a model /// Get the appropriate tool parser for a model
/// ///
/// If a parser name is explicitly configured, use that parser. /// If a parser name is explicitly configured, use that parser.
/// Otherwise, auto-detect based on the model name. /// Otherwise, auto-detect based on the model name.
/// Get a pooled tool parser (for non-streaming where state doesn't matter)
pub fn get_tool_parser( pub fn get_tool_parser(
tool_parser_factory: &crate::tool_parser::ToolParserFactory, tool_parser_factory: &crate::tool_parser::ParserFactory,
configured_parser: Option<&String>, configured_parser: Option<&String>,
model: &str, model: &str,
) -> crate::tool_parser::PooledToolParser { ) -> crate::tool_parser::PooledParser {
use tracing::warn;
if let Some(parser_name) = configured_parser { if let Some(parser_name) = configured_parser {
// Use configured parser if specified // Use configured parser if specified
tool_parser_factory tool_parser_factory
...@@ -731,6 +753,30 @@ pub fn get_tool_parser( ...@@ -731,6 +753,30 @@ pub fn get_tool_parser(
} }
} }
/// Create a fresh tool parser instance (for streaming where state isolation is needed).
///
/// An explicitly configured parser name takes precedence; if it is unknown we
/// log a warning and fall back to model-based selection. Returns `None` only
/// when no parser can be resolved at all.
pub fn create_tool_parser(
    tool_parser_factory: &crate::tool_parser::ParserFactory,
    configured_parser: Option<&String>,
    model: &str,
) -> Option<Box<dyn crate::tool_parser::ToolParser>> {
    let registry = tool_parser_factory.registry();
    match configured_parser {
        // Explicit configuration wins when the named parser exists.
        Some(parser_name) => registry.create_parser(parser_name).or_else(|| {
            warn!(
                "Configured tool parser '{}' not found, falling back to model-based selection",
                parser_name
            );
            registry.create_for_model(model)
        }),
        // No configuration: auto-detect from the model name.
        None => registry.create_for_model(model),
    }
}
/// Convert proto::OutputLogProbs to OpenAI ChatLogProbs format /// Convert proto::OutputLogProbs to OpenAI ChatLogProbs format
/// ///
/// This function decodes token IDs using the tokenizer and builds the logprobs structure /// This function decodes token IDs using the tokenizer and builds the logprobs structure
......
...@@ -18,11 +18,11 @@ use crate::{ ...@@ -18,11 +18,11 @@ use crate::{
}, },
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
}, },
reasoning_parser::ReasoningParserFactory, reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::{router_manager::RouterManager, RouterTrait}, routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tool_parser::ToolParserFactory, tool_parser::ParserFactory as ToolParserFactory,
}; };
use axum::{ use axum::{
extract::{Path, Query, Request, State}, extract::{Path, Query, Request, State},
...@@ -88,8 +88,8 @@ impl AppContext { ...@@ -88,8 +88,8 @@ impl AppContext {
tokenizer_factory::create_tokenizer(&tokenizer_path) tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {e}"))?, .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
); );
let reasoning_parser_factory = Some(ReasoningParserFactory::new()); let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new());
let tool_parser_factory = Some(ToolParserFactory::new()); let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new());
(tokenizer, reasoning_parser_factory, tool_parser_factory) (tokenizer, reasoning_parser_factory, tool_parser_factory)
} else { } else {
......
use thiserror::Error; use thiserror::Error;
/// Result type for tool parser operations /// Result type for tool parser operations
pub type ToolParserResult<T> = Result<T, ToolParserError>; pub type ParserResult<T> = Result<T, ParserError>;
/// Errors that can occur during tool parsing /// Errors that can occur during tool parsing
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ToolParserError { pub enum ParserError {
#[error("Parsing failed: {0}")] #[error("Parsing failed: {0}")]
ParsingFailed(String), ParsingFailed(String),
......
...@@ -11,25 +11,25 @@ use crate::tool_parser::parsers::{ ...@@ -11,25 +11,25 @@ use crate::tool_parser::parsers::{
use crate::tool_parser::traits::ToolParser; use crate::tool_parser::traits::ToolParser;
/// Type alias for pooled parser instances. /// Type alias for pooled parser instances.
pub type PooledToolParser = Arc<Mutex<Box<dyn ToolParser>>>; pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;
/// Type alias for parser creator functions. /// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>; type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
/// Registry for model-specific tool parsers with pooling support. /// Registry for model-specific tool parsers with pooling support.
#[derive(Clone)] #[derive(Clone)]
pub struct ToolParserRegistry { pub struct ParserRegistry {
/// Creator functions for parsers (used when pool is empty) /// Creator functions for parsers (used when pool is empty)
creators: Arc<RwLock<HashMap<String, ParserCreator>>>, creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
/// Pooled parser instances for reuse /// Pooled parser instances for reuse
pool: Arc<RwLock<HashMap<String, PooledToolParser>>>, pool: Arc<RwLock<HashMap<String, PooledParser>>>,
/// Model pattern to parser name mappings /// Model pattern to parser name mappings
model_mapping: Arc<RwLock<HashMap<String, String>>>, model_mapping: Arc<RwLock<HashMap<String, String>>>,
/// Default parser name /// Default parser name
default_parser: Arc<RwLock<String>>, default_parser: Arc<RwLock<String>>,
} }
impl ToolParserRegistry { impl ParserRegistry {
/// Create a new empty registry. /// Create a new empty registry.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
...@@ -57,7 +57,7 @@ impl ToolParserRegistry { ...@@ -57,7 +57,7 @@ impl ToolParserRegistry {
/// Get a pooled parser by exact name. /// Get a pooled parser by exact name.
/// Returns a shared parser instance from the pool, creating one if needed. /// Returns a shared parser instance from the pool, creating one if needed.
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledToolParser> { pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
// First check if we have a pooled instance // First check if we have a pooled instance
{ {
let pool = self.pool.read().unwrap(); let pool = self.pool.read().unwrap();
...@@ -81,8 +81,91 @@ impl ToolParserRegistry { ...@@ -81,8 +81,91 @@ impl ToolParserRegistry {
} }
} }
/// Returns true when a creator for the exact name `name` has been registered.
pub fn has_parser(&self, name: &str) -> bool {
    self.creators.read().unwrap().contains_key(name)
}
/// Build a brand-new, unshared parser instance by exact name.
/// Every call yields fresh state — suited to streaming, where state isolation is needed.
pub fn create_parser(&self, name: &str) -> Option<Box<dyn ToolParser>> {
    let creators = self.creators.read().unwrap();
    let make = creators.get(name)?;
    Some(make())
}
/// Check if a parser can be created for a specific model without actually creating it.
///
/// Mirrors `create_for_model`'s resolution order: exact model-name mapping,
/// then the longest `*`-suffixed prefix pattern, then the default parser.
pub fn has_parser_for_model(&self, model: &str) -> bool {
    let mapping = self.model_mapping.read().unwrap();
    let creators = self.creators.read().unwrap();

    // 1. Exact model-name match.
    if let Some(parser_name) = mapping.get(model) {
        if creators.contains_key(parser_name) {
            return true;
        }
    }

    // 2. Most specific wildcard prefix match ("foo*").
    let prefix_hit = mapping
        .iter()
        .filter(|(pattern, _)| {
            pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
        })
        .max_by_key(|(pattern, _)| pattern.len());
    if let Some((_, parser_name)) = prefix_hit {
        if creators.contains_key(parser_name) {
            return true;
        }
    }

    // 3. Fall back to the registry-wide default.
    creators.contains_key(self.default_parser.read().unwrap().as_str())
}
/// Create a fresh (non-pooled) parser instance for a specific model.
/// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
pub fn create_for_model(&self, model: &str) -> Option<Box<dyn ToolParser>> {
// Try exact match first
{
let mapping = self.model_mapping.read().unwrap();
if let Some(parser_name) = mapping.get(model) {
if let Some(parser) = self.create_parser(parser_name) {
return Some(parser);
}
}
}
// Try prefix matching with more specific patterns first
let model_mapping = self.model_mapping.read().unwrap();
let best_match = model_mapping
.iter()
.filter(|(pattern, _)| {
pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
})
.max_by_key(|(pattern, _)| pattern.len());
// Return the best matching parser
if let Some((_, parser_name)) = best_match {
if let Some(parser) = self.create_parser(parser_name) {
return Some(parser);
}
}
// Fall back to default parser
let default = self.default_parser.read().unwrap().clone();
self.create_parser(&default)
}
/// Get parser for a specific model /// Get parser for a specific model
pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledToolParser> { pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledParser> {
// Try exact match first // Try exact match first
{ {
let mapping = self.model_mapping.read().unwrap(); let mapping = self.model_mapping.read().unwrap();
...@@ -127,7 +210,7 @@ impl ToolParserRegistry { ...@@ -127,7 +210,7 @@ impl ToolParserRegistry {
} }
} }
impl Default for ToolParserRegistry { impl Default for ParserRegistry {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
...@@ -135,14 +218,14 @@ impl Default for ToolParserRegistry { ...@@ -135,14 +218,14 @@ impl Default for ToolParserRegistry {
/// Factory for creating tool parsers based on model type. /// Factory for creating tool parsers based on model type.
#[derive(Clone)] #[derive(Clone)]
pub struct ToolParserFactory { pub struct ParserFactory {
registry: ToolParserRegistry, registry: ParserRegistry,
} }
impl ToolParserFactory { impl ParserFactory {
/// Create a new factory with default parsers registered. /// Create a new factory with default parsers registered.
pub fn new() -> Self { pub fn new() -> Self {
let registry = ToolParserRegistry::new(); let registry = ParserRegistry::new();
// Register default parsers // Register default parsers
registry.register_parser("json", || Box::new(JsonParser::new())); registry.register_parser("json", || Box::new(JsonParser::new()));
...@@ -172,7 +255,7 @@ impl ToolParserFactory { ...@@ -172,7 +255,7 @@ impl ToolParserFactory {
Self { registry } Self { registry }
} }
fn register_default_mappings(registry: &ToolParserRegistry) { fn register_default_mappings(registry: &ParserRegistry) {
// OpenAI models // OpenAI models
registry.map_model("gpt-4*", "json"); registry.map_model("gpt-4*", "json");
registry.map_model("gpt-3.5*", "json"); registry.map_model("gpt-3.5*", "json");
...@@ -229,7 +312,7 @@ impl ToolParserFactory { ...@@ -229,7 +312,7 @@ impl ToolParserFactory {
/// Get a pooled parser for the given model ID. /// Get a pooled parser for the given model ID.
/// Returns a shared instance that can be used concurrently. /// Returns a shared instance that can be used concurrently.
/// Falls back to JSON parser if model is not recognized. /// Falls back to JSON parser if model is not recognized.
pub fn get_pooled(&self, model_id: &str) -> PooledToolParser { pub fn get_pooled(&self, model_id: &str) -> PooledParser {
self.registry self.registry
.get_pooled_for_model(model_id) .get_pooled_for_model(model_id)
.unwrap_or_else(|| { .unwrap_or_else(|| {
...@@ -241,7 +324,7 @@ impl ToolParserFactory { ...@@ -241,7 +324,7 @@ impl ToolParserFactory {
} }
/// Get the internal registry for custom registration. /// Get the internal registry for custom registration.
pub fn registry(&self) -> &ToolParserRegistry { pub fn registry(&self) -> &ParserRegistry {
&self.registry &self.registry
} }
...@@ -299,7 +382,7 @@ impl ToolParserFactory { ...@@ -299,7 +382,7 @@ impl ToolParserFactory {
} }
} }
impl Default for ToolParserFactory { impl Default for ParserFactory {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
......
...@@ -16,8 +16,8 @@ pub mod parsers; ...@@ -16,8 +16,8 @@ pub mod parsers;
mod tests; mod tests;
// Re-export commonly used types // Re-export commonly used types
pub use errors::{ToolParserError, ToolParserResult}; pub use errors::{ParserError, ParserResult};
pub use factory::{PooledToolParser, ToolParserFactory, ToolParserRegistry}; pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use traits::{PartialJsonParser, ToolParser}; pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall}; pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};
......
...@@ -5,7 +5,7 @@ use serde_json::Value; ...@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
...@@ -78,15 +78,15 @@ impl DeepSeekParser { ...@@ -78,15 +78,15 @@ impl DeepSeekParser {
} }
/// Parse a single tool call block - throws error if parsing fails /// Parse a single tool call block - throws error if parsing fails
fn parse_tool_call(&self, block: &str) -> ToolParserResult<ToolCall> { fn parse_tool_call(&self, block: &str) -> ParserResult<ToolCall> {
let captures = self.func_detail_extractor.captures(block).ok_or_else(|| { let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
ToolParserError::ParsingFailed("Failed to match tool call pattern".to_string()) ParserError::ParsingFailed("Failed to match tool call pattern".to_string())
})?; })?;
// Get function type (should be "function") // Get function type (should be "function")
let func_type = captures.get(1).map_or("", |m| m.as_str()); let func_type = captures.get(1).map_or("", |m| m.as_str());
if func_type != "function" { if func_type != "function" {
return Err(ToolParserError::ParsingFailed(format!( return Err(ParserError::ParsingFailed(format!(
"Invalid function type: {}", "Invalid function type: {}",
func_type func_type
))); )));
...@@ -95,7 +95,7 @@ impl DeepSeekParser { ...@@ -95,7 +95,7 @@ impl DeepSeekParser {
// Get function name // Get function name
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim(); let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
if func_name.is_empty() { if func_name.is_empty() {
return Err(ToolParserError::ParsingFailed( return Err(ParserError::ParsingFailed(
"Empty function name".to_string(), "Empty function name".to_string(),
)); ));
} }
...@@ -105,7 +105,7 @@ impl DeepSeekParser { ...@@ -105,7 +105,7 @@ impl DeepSeekParser {
// Parse JSON arguments // Parse JSON arguments
let value = serde_json::from_str::<Value>(json_args) let value = serde_json::from_str::<Value>(json_args)
.map_err(|e| ToolParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?; .map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?;
// Create arguments object // Create arguments object
let args = if value.is_object() { let args = if value.is_object() {
...@@ -115,8 +115,8 @@ impl DeepSeekParser { ...@@ -115,8 +115,8 @@ impl DeepSeekParser {
serde_json::json!({ "value": value }) serde_json::json!({ "value": value })
}; };
let arguments = serde_json::to_string(&args) let arguments =
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; serde_json::to_string(&args).map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(ToolCall { Ok(ToolCall {
function: FunctionCall { function: FunctionCall {
...@@ -135,7 +135,7 @@ impl Default for DeepSeekParser { ...@@ -135,7 +135,7 @@ impl Default for DeepSeekParser {
#[async_trait] #[async_trait]
impl ToolParser for DeepSeekParser { impl ToolParser for DeepSeekParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
...@@ -168,7 +168,7 @@ impl ToolParser for DeepSeekParser { ...@@ -168,7 +168,7 @@ impl ToolParser for DeepSeekParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
...@@ -314,4 +314,12 @@ impl ToolParser for DeepSeekParser { ...@@ -314,4 +314,12 @@ impl ToolParser for DeepSeekParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
fn reset(&mut self) {
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.current_tool_name_sent = false;
self.streamed_args_for_tool.clear();
}
} }
...@@ -5,7 +5,7 @@ use serde_json::Value; ...@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
...@@ -72,7 +72,7 @@ impl Glm4MoeParser { ...@@ -72,7 +72,7 @@ impl Glm4MoeParser {
} }
/// Parse arguments from key-value pairs /// Parse arguments from key-value pairs
fn parse_arguments(&self, args_text: &str) -> ToolParserResult<serde_json::Map<String, Value>> { fn parse_arguments(&self, args_text: &str) -> ParserResult<serde_json::Map<String, Value>> {
let mut arguments = serde_json::Map::new(); let mut arguments = serde_json::Map::new();
for capture in self.arg_extractor.captures_iter(args_text) { for capture in self.arg_extractor.captures_iter(args_text) {
...@@ -110,7 +110,7 @@ impl Glm4MoeParser { ...@@ -110,7 +110,7 @@ impl Glm4MoeParser {
} }
/// Parse a single tool call block /// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> { fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
if let Some(captures) = self.func_detail_extractor.captures(block) { if let Some(captures) = self.func_detail_extractor.captures(block) {
// Get function name // Get function name
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim(); let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
...@@ -122,7 +122,7 @@ impl Glm4MoeParser { ...@@ -122,7 +122,7 @@ impl Glm4MoeParser {
let arguments = self.parse_arguments(args_text)?; let arguments = self.parse_arguments(args_text)?;
let arguments_str = serde_json::to_string(&arguments) let arguments_str = serde_json::to_string(&arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall { Ok(Some(ToolCall {
function: FunctionCall { function: FunctionCall {
...@@ -137,7 +137,7 @@ impl Glm4MoeParser { ...@@ -137,7 +137,7 @@ impl Glm4MoeParser {
/// Parse and return StreamingParseResult (mirrors Python's detect_and_parse) /// Parse and return StreamingParseResult (mirrors Python's detect_and_parse)
/// Parse all tool calls from text (shared logic for complete and incremental parsing) /// Parse all tool calls from text (shared logic for complete and incremental parsing)
fn parse_tool_calls_from_text(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> { fn parse_tool_calls_from_text(&self, text: &str) -> ParserResult<Vec<ToolCall>> {
let mut tools = Vec::new(); let mut tools = Vec::new();
for mat in self.tool_call_extractor.find_iter(text) { for mat in self.tool_call_extractor.find_iter(text) {
...@@ -163,7 +163,7 @@ impl Default for Glm4MoeParser { ...@@ -163,7 +163,7 @@ impl Default for Glm4MoeParser {
#[async_trait] #[async_trait]
impl ToolParser for Glm4MoeParser { impl ToolParser for Glm4MoeParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains GLM-4 MoE format // Check if text contains GLM-4 MoE format
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
...@@ -188,7 +188,7 @@ impl ToolParser for Glm4MoeParser { ...@@ -188,7 +188,7 @@ impl ToolParser for Glm4MoeParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
// Python logic: Wait for complete tool call, then parse it all at once // Python logic: Wait for complete tool call, then parse it all at once
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
...@@ -315,4 +315,11 @@ impl ToolParser for Glm4MoeParser { ...@@ -315,4 +315,11 @@ impl ToolParser for Glm4MoeParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
fn reset(&mut self) {
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.streamed_args_for_tool.clear();
}
} }
...@@ -3,7 +3,7 @@ use async_trait::async_trait; ...@@ -3,7 +3,7 @@ use async_trait::async_trait;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::ToolParserResult, errors::ParserResult,
traits::{TokenToolParser, ToolParser}, traits::{TokenToolParser, ToolParser},
types::{StreamingParseResult, ToolCall}, types::{StreamingParseResult, ToolCall},
}; };
...@@ -23,7 +23,7 @@ impl GptOssHarmonyParser { ...@@ -23,7 +23,7 @@ impl GptOssHarmonyParser {
#[async_trait] #[async_trait]
impl ToolParser for GptOssHarmonyParser { impl ToolParser for GptOssHarmonyParser {
async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, output: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Temporary stub: fall back to returning the raw text with no tool calls. // Temporary stub: fall back to returning the raw text with no tool calls.
// Later phases will decode Harmony tokens into structured tool calls. // Later phases will decode Harmony tokens into structured tool calls.
Ok((output.to_string(), Vec::new())) Ok((output.to_string(), Vec::new()))
...@@ -33,7 +33,7 @@ impl ToolParser for GptOssHarmonyParser { ...@@ -33,7 +33,7 @@ impl ToolParser for GptOssHarmonyParser {
&mut self, &mut self,
_chunk: &str, _chunk: &str,
_tools: &[Tool], _tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
// Temporary stub until the Harmony streaming pipeline is implemented. // Temporary stub until the Harmony streaming pipeline is implemented.
Ok(StreamingParseResult::default()) Ok(StreamingParseResult::default())
} }
...@@ -54,7 +54,7 @@ impl TokenToolParser for GptOssHarmonyParser { ...@@ -54,7 +54,7 @@ impl TokenToolParser for GptOssHarmonyParser {
async fn parse_complete_tokens( async fn parse_complete_tokens(
&self, &self,
_tokens: &[u32], _tokens: &[u32],
) -> ToolParserResult<(String, Vec<ToolCall>)> { ) -> ParserResult<(String, Vec<ToolCall>)> {
// Placeholder until Harmony integration lands. Returning an empty tool list ensures // Placeholder until Harmony integration lands. Returning an empty tool list ensures
// that enabling the parser without full implementation results in a no-op rather // that enabling the parser without full implementation results in a no-op rather
// than a runtime panic. // than a runtime panic.
...@@ -65,7 +65,7 @@ impl TokenToolParser for GptOssHarmonyParser { ...@@ -65,7 +65,7 @@ impl TokenToolParser for GptOssHarmonyParser {
&mut self, &mut self,
_tokens: &[u32], _tokens: &[u32],
_tools: &[Tool], _tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
Ok(StreamingParseResult::default()) Ok(StreamingParseResult::default())
} }
} }
...@@ -5,7 +5,7 @@ use serde_json::Value; ...@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
...@@ -76,7 +76,7 @@ impl Default for GptOssParser { ...@@ -76,7 +76,7 @@ impl Default for GptOssParser {
#[async_trait] #[async_trait]
impl ToolParser for GptOssParser { impl ToolParser for GptOssParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains GPT-OSS format // Check if text contains GPT-OSS format
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
...@@ -100,7 +100,7 @@ impl ToolParser for GptOssParser { ...@@ -100,7 +100,7 @@ impl ToolParser for GptOssParser {
} else { } else {
match serde_json::from_str::<Value>(args_content) { match serde_json::from_str::<Value>(args_content) {
Ok(value) => serde_json::to_string(&value) Ok(value) => serde_json::to_string(&value)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?, .map_err(|e| ParserError::ParsingFailed(e.to_string()))?,
Err(_) => { Err(_) => {
// Skip malformed JSON // Skip malformed JSON
continue; continue;
...@@ -126,7 +126,7 @@ impl ToolParser for GptOssParser { ...@@ -126,7 +126,7 @@ impl ToolParser for GptOssParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
// Check for tool markers // Check for tool markers
...@@ -211,7 +211,7 @@ impl ToolParser for GptOssParser { ...@@ -211,7 +211,7 @@ impl ToolParser for GptOssParser {
partial_args partial_args
}; };
match self.partial_json.parse_value(json_part) { match self.partial_json.parse_value(json_part, true) {
Ok((value, _consumed)) => { Ok((value, _consumed)) => {
let args_str = serde_json::to_string(&value) let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string()); .unwrap_or_else(|_| "{}".to_string());
......
...@@ -2,7 +2,7 @@ use crate::protocols::spec::Tool; ...@@ -2,7 +2,7 @@ use crate::protocols::spec::Tool;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use crate::tool_parser::errors::{ToolParserError, ToolParserResult}; use crate::tool_parser::errors::{ParserError, ParserResult};
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem}; use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
/// Get a mapping of tool names to their indices /// Get a mapping of tool names to their indices
...@@ -14,6 +14,16 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> { ...@@ -14,6 +14,16 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
.collect() .collect()
} }
/// Find the common prefix of two strings
/// Used for incremental argument streaming when partial JSON returns different intermediate states
pub fn find_common_prefix(s1: &str, s2: &str) -> String {
s1.chars()
.zip(s2.chars())
.take_while(|(c1, c2)| c1 == c2)
.map(|(c1, _)| c1)
.collect()
}
/// Get unstreamed tool call arguments /// Get unstreamed tool call arguments
/// Returns tool call items for arguments that have been parsed but not yet streamed /// Returns tool call items for arguments that have been parsed but not yet streamed
/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk /// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
...@@ -96,7 +106,7 @@ pub fn reset_parser_state( ...@@ -96,7 +106,7 @@ pub fn reset_parser_state(
) { ) {
buffer.clear(); buffer.clear();
prev_tool_call_arr.clear(); prev_tool_call_arr.clear();
*current_tool_id = 0; *current_tool_id = -1;
*current_tool_name_sent = false; *current_tool_name_sent = false;
streamed_args_for_tool.clear(); streamed_args_for_tool.clear();
} }
...@@ -169,7 +179,7 @@ pub fn normalize_arguments_field(mut obj: Value) -> Value { ...@@ -169,7 +179,7 @@ pub fn normalize_arguments_field(mut obj: Value) -> Value {
/// ///
/// # Returns /// # Returns
/// - `Ok(StreamingParseResult)` with any tool call items to stream /// - `Ok(StreamingParseResult)` with any tool call items to stream
/// - `Err(ToolParserError)` if JSON parsing or serialization fails /// - `Err(ParserError)` if JSON parsing or serialization fails
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn handle_json_tool_streaming( pub fn handle_json_tool_streaming(
current_text: &str, current_text: &str,
...@@ -181,7 +191,7 @@ pub fn handle_json_tool_streaming( ...@@ -181,7 +191,7 @@ pub fn handle_json_tool_streaming(
current_tool_name_sent: &mut bool, current_tool_name_sent: &mut bool,
streamed_args_for_tool: &mut Vec<String>, streamed_args_for_tool: &mut Vec<String>,
prev_tool_call_arr: &mut Vec<Value>, prev_tool_call_arr: &mut Vec<Value>,
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
// Check if we have content to parse // Check if we have content to parse
if start_idx >= current_text.len() { if start_idx >= current_text.len() {
return Ok(StreamingParseResult::default()); return Ok(StreamingParseResult::default());
...@@ -190,8 +200,12 @@ pub fn handle_json_tool_streaming( ...@@ -190,8 +200,12 @@ pub fn handle_json_tool_streaming(
// Extract JSON string from current position // Extract JSON string from current position
let json_str = &current_text[start_idx..]; let json_str = &current_text[start_idx..];
// When current_tool_name_sent is false, don't allow partial strings to avoid
// parsing incomplete tool names as empty strings
let allow_partial_strings = *current_tool_name_sent;
// Parse partial JSON // Parse partial JSON
let (obj, end_idx) = match partial_json.parse_value(json_str) { let (obj, end_idx) = match partial_json.parse_value(json_str, allow_partial_strings) {
Ok(result) => result, Ok(result) => result,
Err(_) => { Err(_) => {
return Ok(StreamingParseResult::default()); return Ok(StreamingParseResult::default());
...@@ -252,49 +266,68 @@ pub fn handle_json_tool_streaming( ...@@ -252,49 +266,68 @@ pub fn handle_json_tool_streaming(
.map(|s| s.len()) .map(|s| s.len())
.unwrap_or(0); .unwrap_or(0);
let cur_args_json = serde_json::to_string(cur_arguments) let cur_args_json = serde_json::to_string(cur_arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
// Get prev_arguments (matches Python's structure)
let prev_arguments = if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id].get("arguments")
} else {
None
};
// Compute diff: everything after what we've already sent // Calculate diff: everything after we've already sent
let diff = cur_args_json[sent..].to_string(); let mut argument_diff = None;
// Send diff if there's new content if is_complete {
if !diff.is_empty() { // Python: argument_diff = cur_args_json[sent:]
// Only accumulate if not complete // Rust needs bounds check (Python returns "" automatically)
if !is_complete && tool_id < streamed_args_for_tool.len() { argument_diff = if sent < cur_args_json.len() {
streamed_args_for_tool[tool_id].push_str(&diff); Some(cur_args_json[sent..].to_string())
} else {
Some(String::new())
};
} else if let Some(prev_args) = prev_arguments {
let prev_args_json = serde_json::to_string(prev_args)
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
if cur_args_json != prev_args_json {
let prefix = find_common_prefix(&prev_args_json, &cur_args_json);
argument_diff = if sent < prefix.len() {
Some(prefix[sent..].to_string())
} else {
Some(String::new())
};
} }
}
result.calls.push(ToolCallItem { // Send diff if present
tool_index: tool_id, if let Some(diff) = argument_diff {
name: None, if !diff.is_empty() {
parameters: diff, if tool_id < streamed_args_for_tool.len() {
}); streamed_args_for_tool[tool_id].push_str(&diff);
}
result.calls.push(ToolCallItem {
tool_index: tool_id,
name: None,
parameters: diff,
});
}
} }
// If JSON is complete, advance to next tool // Update prev_tool_call_arr with current state
if is_complete { if *current_tool_id >= 0 {
// Remove processed portion, keep unprocessed content ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
*buffer = current_text[start_idx + end_idx..].to_string();
// Clear completed tool data
if tool_id < prev_tool_call_arr.len() { if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id] = Value::Null; prev_tool_call_arr[tool_id] = current_tool_call;
}
*current_tool_name_sent = false;
if tool_id < streamed_args_for_tool.len() {
streamed_args_for_tool[tool_id].clear();
} }
*current_tool_id += 1;
} }
}
// Update prev_tool_call_arr with current state // If complete, advance to next tool
if *current_tool_id >= 0 { if is_complete {
ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool); *buffer = current_text[start_idx + end_idx..].to_string();
let tool_id = *current_tool_id as usize; *current_tool_name_sent = false;
*current_tool_id += 1;
if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id] = current_tool_call;
} }
} }
...@@ -371,7 +404,7 @@ mod tests { ...@@ -371,7 +404,7 @@ mod tests {
assert_eq!(buffer, ""); assert_eq!(buffer, "");
assert_eq!(prev_tools.len(), 0); assert_eq!(prev_tools.len(), 0);
assert_eq!(current_tool_id, 0); assert_eq!(current_tool_id, -1);
assert!(!current_tool_name_sent); assert!(!current_tool_name_sent);
assert_eq!(streamed_args.len(), 0); assert_eq!(streamed_args.len(), 0);
} }
......
...@@ -4,7 +4,7 @@ use serde_json::Value; ...@@ -4,7 +4,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
...@@ -117,7 +117,7 @@ impl JsonParser { ...@@ -117,7 +117,7 @@ impl JsonParser {
} }
/// Parse a single JSON object into a ToolCall /// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> { fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
// Check if this looks like a tool call // Check if this looks like a tool call
let name = obj let name = obj
.get("name") .get("name")
...@@ -134,7 +134,7 @@ impl JsonParser { ...@@ -134,7 +134,7 @@ impl JsonParser {
// Convert arguments to JSON string // Convert arguments to JSON string
let arguments = serde_json::to_string(args) let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall { Ok(Some(ToolCall {
function: FunctionCall { function: FunctionCall {
...@@ -148,7 +148,7 @@ impl JsonParser { ...@@ -148,7 +148,7 @@ impl JsonParser {
} }
/// Parse JSON value(s) into tool calls /// Parse JSON value(s) into tool calls
fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> { fn parse_json_value(&self, value: &Value) -> ParserResult<Vec<ToolCall>> {
let mut tools = Vec::new(); let mut tools = Vec::new();
match value { match value {
...@@ -184,11 +184,11 @@ impl Default for JsonParser { ...@@ -184,11 +184,11 @@ impl Default for JsonParser {
#[async_trait] #[async_trait]
impl ToolParser for JsonParser { impl ToolParser for JsonParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Always use extract_json_from_text to handle both pure JSON and mixed content // Always use extract_json_from_text to handle both pure JSON and mixed content
if let Some((extracted_json, normal_text)) = self.extract_json_from_text(text) { if let Some((extracted_json, normal_text)) = self.extract_json_from_text(text) {
let parsed = serde_json::from_str::<Value>(&extracted_json) let parsed = serde_json::from_str::<Value>(&extracted_json)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string())) .map_err(|e| ParserError::ParsingFailed(e.to_string()))
.and_then(|v| self.parse_json_value(&v)); .and_then(|v| self.parse_json_value(&v));
match parsed { match parsed {
...@@ -205,7 +205,7 @@ impl ToolParser for JsonParser { ...@@ -205,7 +205,7 @@ impl ToolParser for JsonParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
// Append new text to buffer // Append new text to buffer
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
...@@ -264,4 +264,14 @@ impl ToolParser for JsonParser { ...@@ -264,4 +264,14 @@ impl ToolParser for JsonParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
} }
...@@ -5,7 +5,7 @@ use serde_json::Value; ...@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::ToolParserResult, errors::ParserResult,
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
...@@ -102,7 +102,7 @@ impl Default for KimiK2Parser { ...@@ -102,7 +102,7 @@ impl Default for KimiK2Parser {
#[async_trait] #[async_trait]
impl ToolParser for KimiK2Parser { impl ToolParser for KimiK2Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
...@@ -161,7 +161,7 @@ impl ToolParser for KimiK2Parser { ...@@ -161,7 +161,7 @@ impl ToolParser for KimiK2Parser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
...@@ -333,4 +333,13 @@ impl ToolParser for KimiK2Parser { ...@@ -333,4 +333,13 @@ impl ToolParser for KimiK2Parser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
/// Clear all per-request streaming state so this `KimiK2Parser` instance can
/// be reused for a fresh stream without leaking state from the previous one.
fn reset(&mut self) {
    // Drop any partially buffered input and partially streamed argument text.
    self.buffer.clear();
    self.last_arguments.clear();
    self.streamed_args_for_tool.clear();
    // Forget previously observed tool calls and return the cursor fields to
    // their initial sentinel values (-1 / name-not-yet-sent).
    self.prev_tool_call_arr.clear();
    self.current_tool_name_sent = false;
    self.current_tool_id = -1;
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment